فهرست منبع

Merge branch 'dev' into multiple_loaded_models

AUTOMATIC1111 2 سال پیش
والد
کامیت
22ecb78b51
45فایلهای تغییر یافته به همراه924 افزوده شده و 476 حذف شده
  1. 1 1
      extensions-builtin/Lora/ui_edit_user_metadata.py
  2. 5 5
      javascript/localization.js
  3. 2 0
      modules/cmd_args.py
  4. 75 8
      modules/devices.py
  5. 50 0
      modules/errors.py
  6. 7 3
      modules/extensions.py
  7. 19 0
      modules/extra_networks.py
  8. 3 0
      modules/generation_parameters_copypaste.py
  9. 60 0
      modules/gradio_extensons.py
  10. 2 3
      modules/hypernetworks/hypernetwork.py
  11. 1 1
      modules/images.py
  12. 2 4
      modules/img2img.py
  13. 71 40
      modules/processing.py
  14. 7 2
      modules/prompt_parser.py
  15. 102 0
      modules/rng_philox.py
  16. 0 60
      modules/scripts.py
  17. 3 3
      modules/sd_hijack.py
  18. 2 0
      modules/sd_hijack_clip.py
  19. 2 2
      modules/sd_hijack_optimizations.py
  20. 17 12
      modules/sd_models.py
  21. 13 7
      modules/sd_samplers_common.py
  22. 5 4
      modules/sd_samplers_kdiffusion.py
  23. 15 1
      modules/sd_vae.py
  24. 35 18
      modules/shared.py
  25. 1 4
      modules/styles.py
  26. 3 1
      modules/textual_inversion/textual_inversion.py
  27. 2 1
      modules/txt2img.py
  28. 146 206
      modules/ui.py
  29. 1 1
      modules/ui_checkpoint_merger.py
  30. 29 5
      modules/ui_common.py
  31. 1 1
      modules/ui_components.py
  32. 15 11
      modules/ui_extensions.py
  33. 2 11
      modules/ui_extra_networks.py
  34. 4 1
      modules/ui_extra_networks_checkpoints.py
  35. 60 0
      modules/ui_extra_networks_checkpoints_user_metadata.py
  36. 1 1
      modules/ui_extra_networks_hypernets.py
  37. 1 1
      modules/ui_extra_networks_textual_inversion.py
  38. 1 1
      modules/ui_postprocessing.py
  39. 110 0
      modules/ui_prompt_styles.py
  40. 1 1
      modules/ui_settings.py
  41. 2 2
      requirements.txt
  42. 7 7
      requirements_versions.txt
  43. 3 10
      scripts/xyz_grid.py
  44. 27 2
      style.css
  45. 8 35
      webui.py

+ 1 - 1
extensions-builtin/Lora/ui_edit_user_metadata.py

@@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
 
             with gr.Column(scale=1, min_width=120):
-                generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+                generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
 
         self.edit_notes = gr.TextArea(label='Notes', lines=4)
 

+ 5 - 5
javascript/localization.js

@@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
     train_hypernetwork: 'OPTION',
     txt2img_styles: 'OPTION',
     img2img_styles: 'OPTION',
-    setting_random_artist_categories: 'SPAN',
-    setting_face_restoration_model: 'SPAN',
-    setting_realesrgan_enabled_models: 'SPAN',
-    extras_upscaler_1: 'SPAN',
-    extras_upscaler_2: 'SPAN',
+    setting_random_artist_categories: 'OPTION',
+    setting_face_restoration_model: 'OPTION',
+    setting_realesrgan_enabled_models: 'OPTION',
+    extras_upscaler_1: 'OPTION',
+    extras_upscaler_2: 'OPTION',
 };
 
 var re_num = /^[.\d]+$/;

+ 2 - 0
modules/cmd_args.py

@@ -112,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
 parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)

+ 75 - 8
modules/devices.py

@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache
 
 import torch
-from modules import errors
+from modules import errors, rng_philox
 
 if sys.platform == "darwin":
     from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
         torch.backends.cudnn.allow_tf32 = True
 
 
-
 errors.run(enable_tf32, "Enabling TF32")
 
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
 unet_needs_upcast = False
 
 
@@ -90,23 +93,87 @@ def cond_cast_float(input):
     return input.float() if unet_needs_upcast else input
 
 
+nv_rng = None
+
+
 def randn(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
     from modules.shared import opts
 
-    torch.manual_seed(seed)
+    manual_seed(seed)
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
+
     return torch.randn(shape, device=device)
 
 
+def randn_local(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        rng = rng_philox.Generator(seed)
+        return torch.asarray(rng.randn(shape), device=device)
+
+    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+    local_generator = torch.Generator(local_device).manual_seed(int(seed))
+    return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+    if opts.randn_source == "CPU" or x.device.type == 'mps':
+        return torch.randn_like(x, device=cpu).to(x.device)
+
+    return torch.randn_like(x)
+
+
 def randn_without_seed(shape):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
     from modules.shared import opts
 
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
+
     return torch.randn(shape, device=device)
 
 
+def manual_seed(seed):
+    """Set up a global random number generator using the specified seed."""
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        global nv_rng
+        nv_rng = rng_philox.Generator(seed)
+        return
+
+    torch.manual_seed(seed)
+
+
 def autocast(disable=False):
     from modules import shared
 

+ 50 - 0
modules/errors.py

@@ -84,3 +84,53 @@ def run(code, task):
         code()
     except Exception as e:
         display(task, e)
+
+
+def check_versions():
+    from packaging import version
+    from modules import shared
+
+    import torch
+    import gradio
+
+    expected_torch_version = "2.0.0"
+    expected_xformers_version = "0.0.20"
+    expected_gradio_version = "3.39.0"
+
+    if version.parse(torch.__version__) < version.parse(expected_torch_version):
+        print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
+
+    if shared.xformers_available:
+        import xformers
+
+        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+            print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+            """.strip())
+
+    if gradio.__version__ != expected_gradio_version:
+        print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Reasons why you have the mismatched gradio version can be:
+  - you use --skip-install flag.
+  - you use webui.py to start the program instead of launch.py.
+  - an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
+

+ 7 - 3
modules/extensions.py

@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
 
 
 def active():
-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
         return []
-    elif shared.opts.disable_all_extensions == "extra":
+    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
         return [x for x in extensions if x.enabled and x.is_builtin]
     else:
         return [x for x in extensions if x.enabled]
@@ -141,8 +141,12 @@ def list_extensions():
     if not os.path.isdir(extensions_dir):
         return
 
-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions:
+        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+    elif shared.opts.disable_all_extensions == "all":
         print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+    elif shared.cmd_opts.disable_extra_extensions:
+        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
     elif shared.opts.disable_all_extensions == "extra":
         print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
 

+ 19 - 0
modules/extra_networks.py

@@ -1,3 +1,5 @@
+import json
+import os
 import re
 from collections import defaultdict
 
@@ -177,3 +179,20 @@ def parse_prompts(prompts):
 
     return res, extra_data
 
+
+def get_user_metadata(filename):
+    if filename is None:
+        return {}
+
+    basename, ext = os.path.splitext(filename)
+    metadata_filename = basename + '.json'
+
+    metadata = {}
+    try:
+        if os.path.isfile(metadata_filename):
+            with open(metadata_filename, "r", encoding="utf8") as file:
+                metadata = json.load(file)
+    except Exception as e:
+        errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+    return metadata

+ 3 - 0
modules/generation_parameters_copypaste.py

@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires sampler" not in res:
         res["Hires sampler"] = "Use same sampler"
 
+    if "Hires checkpoint" not in res:
+        res["Hires checkpoint"] = "Use same checkpoint"
+
     if "Hires prompt" not in res:
         res["Hires prompt"] = ""
 

+ 60 - 0
modules/gradio_extensons.py

@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import scripts
+
+def add_classes_to_gradio_component(comp):
+    """
+    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+    """
+
+    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+
+    if getattr(comp, 'multiselect', False):
+        comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+    self.webui_tooltip = kwargs.pop('tooltip', None)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.before_component(self, **kwargs)
+
+    scripts.script_callbacks.before_component_callback(self, **kwargs)
+
+    res = original_IOComponent_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.after_component(self, **kwargs)
+
+    return res
+
+
+def Block_get_config(self):
+    config = original_Block_get_config(self)
+
+    webui_tooltip = getattr(self, 'webui_tooltip', None)
+    if webui_tooltip:
+        config["webui_tooltip"] = webui_tooltip
+
+    return config
+
+
+def BlockContext_init(self, *args, **kwargs):
+    res = original_BlockContext_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.blocks.Block.get_config
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+
+gr.components.IOComponent.__init__ = IOComponent_init
+gr.blocks.Block.get_config = Block_get_config
+gr.blocks.BlockContext.__init__ = BlockContext_init

+ 2 - 3
modules/hypernetworks/hypernetwork.py

@@ -10,7 +10,7 @@ import torch
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
 
 
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
-    from modules import images
+    from modules import images, processing
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0

+ 1 - 1
modules/images.py

@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
     return res
 
 
-invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n\r\t'
 invalid_filename_prefix = ' '
 invalid_filename_postfix = ' .'
 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')

+ 2 - 4
modules/img2img.py

@@ -3,7 +3,7 @@ from contextlib import closing
 from pathlib import Path
 
 import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
 import gradio as gr
 
 from modules import sd_samplers, images as imgutil
@@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         mask = None
     elif mode == 2:  # inpaint
         image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
-        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
-        mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
-        mask = ImageChops.lighter(alpha_mask, mask).convert('L')
+        mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
         image = image.convert("RGB")
     elif mode == 3:  # inpaint sketch
         image = inpaint_color_sketch

+ 71 - 40
modules/processing.py

@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 
+decode_first_stage = sd_samplers_common.decode_first_stage
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
@@ -492,7 +493,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
 
         subnoise = None
-        if subseeds is not None:
+        if subseeds is not None and subseed_strength != 0:
             subseed = 0 if i >= len(subseeds) else subseeds[i]
 
             subnoise = devices.randn(subseed, noise_shape)
@@ -524,7 +525,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
             cnt = p.sampler.number_of_needed_noises(p)
 
             if eta_noise_seed_delta > 0:
-                torch.manual_seed(seed + eta_noise_seed_delta)
+                devices.manual_seed(seed + eta_noise_seed_delta)
 
             for j in range(cnt):
                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -538,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     return x
 
 
+class DecodedSamples(list):
+    already_decoded = True
+
+
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
-    samples = []
+    samples = DecodedSamples()
 
     for i in range(batch.shape[0]):
         sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -572,12 +577,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
     return samples
 
 
-def decode_first_stage(model, x):
-    x = model.decode_first_stage(x.to(devices.dtype_vae))
-
-    return x
-
-
 def get_fixed_seed(seed):
     if seed is None or seed == '' or seed == -1:
         return int(random.randrange(4294967294))
@@ -636,7 +635,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
-        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
@@ -793,7 +792,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+            if getattr(samples_ddim, 'already_decoded', False):
+                x_samples_ddim = samples_ddim
+            else:
+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
@@ -935,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -946,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_resize_y = hr_resize_y
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_y = hr_resize_y
+        self.hr_checkpoint_name = hr_checkpoint_name
+        self.hr_checkpoint_info = None
         self.hr_sampler_name = hr_sampler_name
         self.hr_prompt = hr_prompt
         self.hr_negative_prompt = hr_negative_prompt
         self.all_hr_prompts = None
         self.all_hr_negative_prompts = None
+        self.latent_scale_mode = None
 
         if firstphase_width != 0 or firstphase_height != 0:
             self.hr_upscale_to_x = self.width
@@ -973,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
+            if self.hr_checkpoint_name:
+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                if self.hr_checkpoint_info is None:
+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
 
@@ -982,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
 
+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+            if self.enable_hr and self.latent_scale_mode is None:
+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                 self.hr_resize_x = self.width
                 self.hr_resize_y = self.height
@@ -1020,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                     self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                     self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
 
-            # special case: the user has chosen to do nothing
-            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
-                self.enable_hr = False
-                self.denoising_strength = None
-                self.extra_generation_params.pop("Hires upscale", None)
-                self.extra_generation_params.pop("Hires resize", None)
-                return
-
             if not state.processing_has_refined_job_count:
                 if state.job_count == -1:
                     state.job_count = self.n_iter
@@ -1045,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
-        if self.enable_hr and latent_scale_mode is None:
-            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
-                raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+        del x
 
         if not self.enable_hr:
             return samples
 
+        if self.latent_scale_mode is None:
+            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+        else:
+            decoded_samples = None
+
+        current = shared.sd_model.sd_checkpoint_info
+        try:
+            if self.hr_checkpoint_info is not None:
+                self.sampler = None
+                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+                devices.torch_gc()
+
+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+        finally:
+            self.sampler = None
+            sd_models.reload_model_weights(info=current)
+            devices.torch_gc()
+
+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         self.is_hr_pass = True
 
         target_width = self.hr_upscale_to_x
@@ -1073,11 +1099,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
             images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
 
-        if latent_scale_mode is not None:
+        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+            img2img_sampler_name = 'DDIM'
+
+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+        if self.latent_scale_mode is not None:
             for i in range(samples.shape[0]):
                 save_intermediate(samples, i)
 
-            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
 
             # Avoid making the inpainting conditioning unless necessary as
             # this does need some extra compute to decode / encode the image again.
@@ -1086,7 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
-            decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             batch_images = []
@@ -1105,6 +1137,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             decoded_samples = torch.from_numpy(np.array(batch_images))
             decoded_samples = decoded_samples.to(shared.device)
             decoded_samples = 2. * decoded_samples - 1.
+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
 
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
@@ -1112,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         shared.state.nextjob()
 
-        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
-            img2img_sampler_name = 'DDIM'
-
-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
 
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
         # GC now before running the next img2img to prevent running out of memory
-        x = None
         devices.torch_gc()
 
         if not self.disable_extra_networks:
@@ -1143,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 
+        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
         self.is_hr_pass = False
 
-        return samples
+        return decoded_samples
 
     def close(self):
         super().close()
@@ -1184,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if self.hr_c is not None:
             return
 
-        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
-        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
 
     def setup_conds(self):
         super().setup_conds()
@@ -1193,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
         self.hr_c = None
 
-        if self.enable_hr:
+        if self.enable_hr and self.hr_checkpoint_info is None:
             if shared.opts.hires_fix_use_firstpass_conds:
                 self.calculate_hr_conds()
 
@@ -1348,6 +1378,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         image = image.to(shared.device, dtype=devices.dtype_vae)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+        devices.torch_gc()
 
         if self.resize_mode == 3:
             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

+ 7 - 2
modules/prompt_parser.py

@@ -20,7 +20,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
         | "(" prompt ":" prompt ")"
         | "[" prompt "]"
 scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
-alternate: "[" prompt ("|" prompt)+ "]"
+alternate: "[" prompt ("|" [prompt])+ "]"
 WHITESPACE: /\s+/
 plain: /([^\\\[\]():|]|\\.)+/
 %import common.SIGNED_NUMBER -> NUMBER
@@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     [[3, '((a][:b:c '], [10, '((a][:b:c d']]
     >>> g("[a|(b:1.1)]")
     [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+    >>> g("[fe|]male")
+    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
+    >>> g("[fe|||]male")
+    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
     """
 
     def collect_steps(steps, tree):
@@ -78,7 +82,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
                 before, after, _, when, _ = args
                 yield before or () if step <= when else after
             def alternate(self, args):
-                yield next(args[(step - 1)%len(args)])
+                args = ["" if not arg else arg for arg in args]
+                yield args[(step - 1) % len(args)]
             def start(self, args):
                 def flatten(x):
                     if type(x) == str:

+ 102 - 0
modules/rng_philox.py

@@ -0,0 +1,102 @@
+"""RNG imitiating torch cuda randn on CPU. You are welcome.
+
+Usage:
+
+```
+g = Generator(seed=0)
+print(g.randn(shape=(3, 4)))
+```
+
+Expected output:
+```
+[[-0.92466259 -0.42534415 -2.6438457   0.14518388]
+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
+ [-1.07454231 -0.36314407 -1.67105067  2.26550497]]
+```
+"""
+
+import numpy as np
+
+philox_m = [0xD2511F53, 0xCD9E8D57]
+philox_w = [0x9E3779B9, 0xBB67AE85]
+
+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
+
+
+def uint32(x):
+    """Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
+    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
+
+
+def philox4_round(counter, key):
+    """A single round of the Philox 4x32 random number generator."""
+
+    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
+    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
+
+    counter[0] = v2[1] ^ counter[1] ^ key[0]
+    counter[1] = v2[0]
+    counter[2] = v1[1] ^ counter[3] ^ key[1]
+    counter[3] = v1[0]
+
+
+def philox4_32(counter, key, rounds=10):
+    """Generates 32-bit random numbers using the Philox 4x32 random number generator.
+
+    Parameters:
+        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
+        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
+        rounds (int): The number of rounds to perform.
+
+    Returns:
+        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
+    """
+
+    for _ in range(rounds - 1):
+        philox4_round(counter, key)
+
+        key[0] = key[0] + philox_w[0]
+        key[1] = key[1] + philox_w[1]
+
+    philox4_round(counter, key)
+    return counter
+
+
+def box_muller(x, y):
+    """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
+    u = x * two_pow32_inv + two_pow32_inv / 2
+    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+
+    s = np.sqrt(-2.0 * np.log(u))
+
+    r1 = s * np.sin(v)
+    return r1.astype(np.float32)
+
+
+class Generator:
+    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
+
+    def __init__(self, seed):
+        self.seed = seed
+        self.offset = 0
+
+    def randn(self, shape):
+        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
+
+        n = 1
+        for x in shape:
+            n *= x
+
+        counter = np.zeros((4, n), dtype=np.uint32)
+        counter[0] = self.offset
+        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
+        self.offset += 1
+
+        key = np.empty(n, dtype=np.uint64)
+        key.fill(self.seed)
+        key = uint32(key)
+
+        g = philox4_32(counter, key)
+
+        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]

+ 0 - 60
modules/scripts.py

@@ -631,63 +631,3 @@ def reload_script_body_only():
 
 
 reload_scripts = load_scripts  # compatibility alias
-
-
-def add_classes_to_gradio_component(comp):
-    """
-    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
-    """
-
-    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
-
-    if getattr(comp, 'multiselect', False):
-        comp.elem_classes.append('multiselect')
-
-
-
-def IOComponent_init(self, *args, **kwargs):
-    self.webui_tooltip = kwargs.pop('tooltip', None)
-
-    if scripts_current is not None:
-        scripts_current.before_component(self, **kwargs)
-
-    script_callbacks.before_component_callback(self, **kwargs)
-
-    res = original_IOComponent_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    script_callbacks.after_component_callback(self, **kwargs)
-
-    if scripts_current is not None:
-        scripts_current.after_component(self, **kwargs)
-
-    return res
-
-
-def Block_get_config(self):
-    config = original_Block_get_config(self)
-
-    webui_tooltip = getattr(self, 'webui_tooltip', None)
-    if webui_tooltip:
-        config["webui_tooltip"] = webui_tooltip
-
-    return config
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-original_Block_get_config = gr.components.Block.get_config
-gr.components.IOComponent.__init__ = IOComponent_init
-gr.components.Block.get_config = Block_get_config
-
-
-def BlockContext_init(self, *args, **kwargs):
-    res = original_BlockContext_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init

+ 3 - 3
modules/sd_hijack.py

@@ -2,7 +2,6 @@ import torch
 from torch.nn.functional import silu
 from types import MethodType
 
-import modules.textual_inversion.textual_inversion
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
@@ -168,12 +167,13 @@ class StableDiffusionModelHijack:
     clip = None
     optimization_method = None
 
-    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
     def __init__(self):
+        import modules.textual_inversion.textual_inversion
+
         self.extra_generation_params = {}
         self.comments = []
 
+        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
 
     def apply_optimizations(self, option=None):

+ 2 - 0
modules/sd_hijack_clip.py

@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                 hashes.append(f"{name}: {shorthash}")
 
             if hashes:
+                if self.hijack.extra_generation_params.get("TI hashes"):
+                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
                 self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
 
         if getattr(self.wrapped, 'return_pooled', False):

+ 2 - 2
modules/sd_hijack_optimizations.py

@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                                f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
 
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        slice_size = q.shape[1] // steps
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(i + slice_size, q.shape[1])
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
 
             s2 = s1.softmax(dim=-1, dtype=q.dtype)

+ 17 - 12
modules/sd_models.py

@@ -66,8 +66,9 @@ class CheckpointInfo:
         self.shorthash = self.sha256[0:10] if self.sha256 else None
 
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
 
-        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
     def register(self):
         checkpoints_list[self.title] = self
@@ -86,6 +87,7 @@ class CheckpointInfo:
 
         checkpoints_list.pop(self.title, None)
         self.title = f'{self.name} [{self.shorthash}]'
+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
         self.register()
 
         return self.shorthash
@@ -106,14 +108,8 @@ def setup_model():
     enable_midas_autodownload()
 
 
-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
 
 
 def list_models():
@@ -136,11 +132,14 @@ def list_models():
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
 
-    for filename in sorted(model_list, key=str.lower):
+    for filename in model_list:
         checkpoint_info = CheckpointInfo(filename)
         checkpoint_info.register()
 
 
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
 def get_closet_checkpoint_match(search_string):
     checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
@@ -150,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
     if found:
         return found[0]
 
+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
     return None
 
 
@@ -302,12 +306,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         sd_models_xl.extend_sdxl(model)
 
     model.load_state_dict(state_dict, strict=False)
-    del state_dict
     timer.record("apply weights to model")
 
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+        checkpoints_loaded[checkpoint_info] = state_dict
+
+    del state_dict
 
     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)

+ 13 - 7
modules/sd_samplers_common.py

@@ -2,10 +2,8 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
-
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
 from modules.shared import opts, state
-import modules.shared as shared
 
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
 
@@ -37,7 +35,7 @@ def single_sample_to_image(sample, approximation=None):
         x_sample = sample * 1.5
         x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
     else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+        x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
 
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -46,6 +44,12 @@ def single_sample_to_image(sample, approximation=None):
     return Image.fromarray(x_sample)
 
 
+def decode_first_stage(model, x):
+    x = model.decode_first_stage(x.to(devices.dtype_vae))
+
+    return x
+
+
 def sample_to_image(samples, index=0, approximation=None):
     return single_sample_to_image(samples[index], approximation)
 
@@ -85,11 +89,13 @@ class InterruptedException(BaseException):
     pass
 
 
-if opts.randn_source == "CPU":
+def replace_torchsde_browinan():
     import torchsde._brownian.brownian_interval
 
     def torchsde_randn(size, dtype, device, seed):
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+        return devices.randn_local(seed, size).to(device=device, dtype=dtype)
 
     torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_browinan()

+ 5 - 4
modules/sd_samplers_kdiffusion.py

@@ -30,6 +30,7 @@ samplers_k_diffusion = [
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
     ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
     ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+    ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
     ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
 ]
 
@@ -260,10 +261,7 @@ class TorchHijack:
             if noise.shape == x.shape:
                 return noise
 
-        if opts.randn_source == "CPU" or x.device.type == 'mps':
-            return torch.randn_like(x, device=devices.cpu).to(x.device)
-        else:
-            return torch.randn_like(x)
+        return devices.randn_like(x)
 
 
 class KDiffusionSampler:
@@ -378,6 +376,9 @@ class KDiffusionSampler:
             sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
 
             sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+        elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
+            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+            sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
 

+ 15 - 1
modules/sd_vae.py

@@ -1,6 +1,6 @@
 import os
 import collections
-from modules import paths, shared, devices, script_callbacks, sd_models
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 from copy import deepcopy
 
@@ -16,6 +16,7 @@ checkpoint_info = None
 
 checkpoints_loaded = collections.OrderedDict()
 
+
 def get_base_vae(model):
     if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
         return base_vae
@@ -50,6 +51,7 @@ def get_filename(filepath):
 
 
 def refresh_vae_list():
+    global vae_dict
     vae_dict.clear()
 
     paths = [
@@ -83,6 +85,8 @@ def refresh_vae_list():
         name = get_filename(filepath)
         vae_dict[name] = filepath
 
+    vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+
 
 def find_vae_near_checkpoint(checkpoint_file):
     checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
@@ -97,6 +101,16 @@ def resolve_vae(checkpoint_file):
     if shared.cmd_opts.vae_path is not None:
         return shared.cmd_opts.vae_path, 'from commandline argument'
 
+    metadata = extra_networks.get_user_metadata(checkpoint_file)
+    vae_metadata = metadata.get("vae", None)
+    if vae_metadata is not None and vae_metadata != "Automatic":
+        if vae_metadata == "None":
+            return None, None
+
+        vae_from_metadata = vae_dict.get(vae_metadata, None)
+        if vae_from_metadata is not None:
+            return vae_from_metadata, "from user metadata"
+
     is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
 
     vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)

+ 35 - 18
modules/shared.py

@@ -220,12 +220,19 @@ class State:
             return
 
         import modules.sd_samplers
-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
 
-        self.current_image_sampling_step = self.sampling_step
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+
+        except Exception:
+            # when switching models during genration, VAE would be on CPU, so creating an image will fail.
+            # we silently ignore this error
+            errors.record_exception()
 
     def assign_current_image(self, image):
         self.current_image = image
@@ -385,7 +392,8 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
 }))
 
 options_templates.update(options_section(('system', "System"), {
-    "show_warnings": OptionInfo(False, "Show warnings in console."),
+    "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),
+    "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
@@ -419,19 +427,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
-    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
-    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
-    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
-    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
-    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
+    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(),
     "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
-    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
 }))
 
 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
@@ -441,6 +444,22 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
     "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
 }))
 
+
+options_templates.update(options_section(('img2img', "img2img"), {
+    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
+    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
+    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
+    "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
+    "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(),
+    "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker,  {}).info("brush color of inpaint mask").needs_restart(),
+    "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(),
+    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
+}))
+
+
 options_templates.update(options_section(('optimizations', "Optimizations"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
     "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
@@ -460,7 +479,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
 }))
 
-options_templates.update(options_section(('interrogate', "Interrogate Options"), {
+options_templates.update(options_section(('interrogate', "Interrogate"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
     "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
     "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
@@ -493,10 +512,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
 options_templates.update(options_section(('ui', "User interface"), {
     "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
     "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
-    "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
-    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
-    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -515,11 +531,12 @@ options_templates.update(options_section(('ui', "User interface"), {
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
-    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
+    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),
     "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
 }))
 
+
 options_templates.update(options_section(('infotext', "Infotext"), {
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),

+ 1 - 4
modules/styles.py

@@ -106,10 +106,7 @@ class StyleDatabase:
         if os.path.exists(path):
             shutil.copy(path, f"{path}.bak")
 
-        fd = os.open(path, os.O_RDWR | os.O_CREAT)
-        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
-            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
-            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+        with open(path, "w", encoding="utf-8-sig", newline='') as file:
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
             writer.writeheader()
             writer.writerows(style._asdict() for k, style in self.styles.items())

+ 3 - 1
modules/textual_inversion/textual_inversion.py

@@ -13,7 +13,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from torch.utils.tensorboard import SummaryWriter
 
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
@@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
 
 
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    from modules import processing
+
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)

+ 2 - 1
modules/txt2img.py

@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     p = processing.StableDiffusionProcessingTxt2Img(
@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_second_pass_steps=hr_second_pass_steps,
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
+        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,

+ 146 - 206
modules/ui.py

@@ -12,34 +12,30 @@ import numpy as np
 from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
+from modules import gradio_extensons  # noqa: F401
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button
 from modules.ui_gradio_extensions import reload_javascript
 
-
 from modules.shared import opts, cmd_opts
 
-import modules.codeformer_model
 import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
+import modules.hypernetworks.ui as hypernetworks_ui
+import modules.textual_inversion.ui as textual_inversion_ui
+import modules.textual_inversion.textual_inversion as textual_inversion
 import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
+import modules.images
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
 from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras
 
 create_setting_component = ui_settings.create_setting_component
 
 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
 
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
@@ -92,19 +88,6 @@ def send_gradio_gallery_to_image(x):
     return image_from_url_text(x[0])
 
 
-def add_style(name: str, prompt: str, negative_prompt: str):
-    if name is None:
-        return [gr_show() for x in range(4)]
-
-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
-    shared.prompt_styles.styles[style.name] = style
-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
-    # reserialize all styles every time we save them
-    shared.prompt_styles.save_styles(shared.styles_filename)
-
-    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
 def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
     from modules import processing, devices
 
@@ -129,13 +112,6 @@ def resize_from_to_html(width, height, scale_by):
     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
 
 
-def apply_styles(prompt, prompt_neg, styles):
-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
-    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
-    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
     if mode in {0, 1, 3, 4}:
         return [interrogation_function(ii_singles[mode]), None]
@@ -172,7 +148,6 @@ def interrogate_deepbooru(image):
 def create_seed_inputs(target_interface):
     with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
         seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
-        seed.style(container=False)
         random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
         reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
 
@@ -184,7 +159,6 @@ def create_seed_inputs(target_interface):
     with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
         seed_extras.append(seed_extra_row_1)
         subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
-        subseed.style(container=False)
         random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
         reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
         subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
@@ -267,71 +241,78 @@ def update_token_counter(text, steps):
     return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
 
 
-def create_toprow(is_img2img):
-    id_part = "img2img" if is_img2img else "txt2img"
+class Toprow:
+    """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
 
-    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
-        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+    def __init__(self, is_img2img):
+        id_part = "img2img" if is_img2img else "txt2img"
+        self.id_part = id_part
 
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
-        button_interrogate = None
-        button_deepbooru = None
-        if is_img2img:
-            with gr.Column(scale=1, elem_classes="interrogate-col"):
-                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
-                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
-        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
-            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
-                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
-                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
-                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
-                skip.click(
-                    fn=lambda: shared.state.skip(),
-                    inputs=[],
-                    outputs=[],
-                )
+        with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+            with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+                            self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
 
-                interrupt.click(
-                    fn=lambda: shared.state.interrupt(),
-                    inputs=[],
-                    outputs=[],
-                )
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
 
-            with gr.Row(elem_id=f"{id_part}_tools"):
-                paste = ToolButton(value=paste_symbol, elem_id="paste")
-                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
-                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
-                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
-                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
-                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
-                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
-                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
-                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
-                clear_prompt_button.click(
-                    fn=lambda *x: x,
-                    _js="confirm_clear_prompt",
-                    inputs=[prompt, negative_prompt],
-                    outputs=[prompt, negative_prompt],
-                )
 
-            with gr.Row(elem_id=f"{id_part}_styles_row"):
-                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
-                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+            self.button_interrogate = None
+            self.button_deepbooru = None
+            if is_img2img:
+                with gr.Column(scale=1, elem_classes="interrogate-col"):
+                    self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+                    self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
 
-    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+            with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+                with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+                    self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+                    self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+                    self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+                    self.skip.click(
+                        fn=lambda: shared.state.skip(),
+                        inputs=[],
+                        outputs=[],
+                    )
+
+                    self.interrupt.click(
+                        fn=lambda: shared.state.interrupt(),
+                        inputs=[],
+                        outputs=[],
+                    )
+
+                with gr.Row(elem_id=f"{id_part}_tools"):
+                    self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+                    self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+                    self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+                    self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+                    self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+                    self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+                    self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+                    self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+                    self.clear_prompt_button.click(
+                        fn=lambda *x: x,
+                        _js="confirm_clear_prompt",
+                        inputs=[self.prompt, self.negative_prompt],
+                        outputs=[self.prompt, self.negative_prompt],
+                    )
+
+                self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
+
+        self.prompt_img.change(
+            fn=modules.images.image_data,
+            inputs=[self.prompt_img],
+            outputs=[self.prompt, self.prompt_img],
+            show_progress=False,
+        )
 
 
 def setup_progressbar(*args, **kwargs):
@@ -415,22 +396,21 @@ def create_ui():
 
     parameters_copypaste.reset()
 
-    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
-    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+    scripts.scripts_current = scripts.scripts_txt2img
+    scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
 
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
-        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+        toprow = Toprow(is_img2img=False)
 
         dummy_component = gr.Label(visible=False)
-        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
 
         with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
             from modules import ui_extra_networks
-            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
 
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
-                modules.scripts.scripts_txt2img.prepare_ui()
+                scripts.scripts_txt2img.prepare_ui()
 
                 for category in ordered_ui_categories():
                     if category == "sampler":
@@ -476,6 +456,10 @@ def create_ui():
                                 hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
 
                             with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+
+                                hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+                                create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+
                                 hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
 
                             with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
@@ -498,10 +482,10 @@ def create_ui():
 
                     elif category == "scripts":
                         with FormGroup(elem_id="txt2img_script_container"):
-                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+                            custom_inputs = scripts.scripts_txt2img.setup_ui()
 
                     else:
-                        modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+                        scripts.scripts_txt2img.setup_ui_for_section(category)
 
             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
 
@@ -532,9 +516,9 @@ def create_ui():
                 _js="submit",
                 inputs=[
                     dummy_component,
-                    txt2img_prompt,
-                    txt2img_negative_prompt,
-                    txt2img_prompt_styles,
+                    toprow.prompt,
+                    toprow.negative_prompt,
+                    toprow.ui_styles.dropdown,
                     steps,
                     sampler_index,
                     restore_faces,
@@ -553,6 +537,7 @@ def create_ui():
                     hr_second_pass_steps,
                     hr_resize_x,
                     hr_resize_y,
+                    hr_checkpoint_name,
                     hr_sampler_index,
                     hr_prompt,
                     hr_negative_prompt,
@@ -569,12 +554,12 @@ def create_ui():
                 show_progress=False,
             )
 
-            txt2img_prompt.submit(**txt2img_args)
-            submit.click(**txt2img_args)
+            toprow.prompt.submit(**txt2img_args)
+            toprow.submit.click(**txt2img_args)
 
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
 
-            restore_progress_button.click(
+            toprow.restore_progress_button.click(
                 fn=progress.restore_progress,
                 _js="restoreProgressTxt2img",
                 inputs=[dummy_component],
@@ -587,18 +572,6 @@ def create_ui():
                 show_progress=False,
             )
 
-            txt_prompt_img.change(
-                fn=modules.images.image_data,
-                inputs=[
-                    txt_prompt_img
-                ],
-                outputs=[
-                    txt2img_prompt,
-                    txt_prompt_img
-                ],
-                show_progress=False,
-            )
-
             enable_hr.change(
                 fn=lambda x: gr_show(x),
                 inputs=[enable_hr],
@@ -607,8 +580,8 @@ def create_ui():
             )
 
             txt2img_paste_fields = [
-                (txt2img_prompt, "Prompt"),
-                (txt2img_negative_prompt, "Negative prompt"),
+                (toprow.prompt, "Prompt"),
+                (toprow.negative_prompt, "Negative prompt"),
                 (steps, "Steps"),
                 (sampler_index, "Sampler"),
                 (restore_faces, "Face restoration"),
@@ -617,34 +590,36 @@ def create_ui():
                 (width, "Size-1"),
                 (height, "Size-2"),
                 (batch_size, "Batch size"),
+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
                 (subseed, "Variation seed"),
                 (subseed_strength, "Variation seed strength"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
-                (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                 (denoising_strength, "Denoising strength"),
-                (enable_hr, lambda d: "Denoising strength" in d),
-                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+                (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
+                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
                 (hr_scale, "Hires upscale"),
                 (hr_upscaler, "Hires upscaler"),
                 (hr_second_pass_steps, "Hires steps"),
                 (hr_resize_x, "Hires resize-1"),
                 (hr_resize_y, "Hires resize-2"),
+                (hr_checkpoint_name, "Hires checkpoint"),
                 (hr_sampler_index, "Hires sampler"),
-                (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
+                (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
                 (hr_prompt, "Hires prompt"),
                 (hr_negative_prompt, "Hires negative prompt"),
                 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
-                *modules.scripts.scripts_txt2img.infotext_fields
+                *scripts.scripts_txt2img.infotext_fields
             ]
             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+                paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
             ))
 
             txt2img_preview_params = [
-                txt2img_prompt,
-                txt2img_negative_prompt,
+                toprow.prompt,
+                toprow.negative_prompt,
                 steps,
                 sampler_index,
                 cfg_scale,
@@ -653,24 +628,22 @@ def create_ui():
                 height,
             ]
 
-            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+            toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+            toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
             ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
 
-    modules.scripts.scripts_current = modules.scripts.scripts_img2img
-    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+    scripts.scripts_current = scripts.scripts_img2img
+    scripts.scripts_img2img.initialize_scripts(is_img2img=True)
 
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
-        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
-
-        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+        toprow = Toprow(is_img2img=True)
 
         with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
             from modules import ui_extra_networks
-            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
 
-        with FormRow().style(equal_height=False):
+        with FormRow(equal_height=False):
             with gr.Column(variant='compact', elem_id="img2img_settings"):
                 copy_image_buttons = []
                 copy_image_destinations = {}
@@ -692,19 +665,19 @@ def create_ui():
                     img2img_selected_tab = gr.State(0)
 
                     with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
-                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
                         add_copy_image_controls('img2img', init_img)
 
                     with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
-                        sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
                         add_copy_image_controls('sketch', sketch)
 
                     with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
-                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
                         add_copy_image_controls('inpaint', init_img_with_mask)
 
                     with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
-                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
                         inpaint_color_sketch_orig = gr.State(None)
                         add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
 
@@ -764,7 +737,7 @@ def create_ui():
                 with FormRow():
                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
 
-                modules.scripts.scripts_img2img.prepare_ui()
+                scripts.scripts_img2img.prepare_ui()
 
                 for category in ordered_ui_categories():
                     if category == "sampler":
@@ -845,7 +818,7 @@ def create_ui():
 
                     elif category == "scripts":
                         with FormGroup(elem_id="img2img_script_container"):
-                            custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+                            custom_inputs = scripts.scripts_img2img.setup_ui()
 
                     elif category == "inpaint":
                         with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
@@ -876,34 +849,22 @@ def create_ui():
                                     outputs=[inpaint_controls, mask_alpha],
                                 )
                     else:
-                        modules.scripts.scripts_img2img.setup_ui_for_section(category)
+                        scripts.scripts_img2img.setup_ui_for_section(category)
 
             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
 
             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
 
-            img2img_prompt_img.change(
-                fn=modules.images.image_data,
-                inputs=[
-                    img2img_prompt_img
-                ],
-                outputs=[
-                    img2img_prompt,
-                    img2img_prompt_img
-                ],
-                show_progress=False,
-            )
-
             img2img_args = dict(
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                 _js="submit_img2img",
                 inputs=[
                     dummy_component,
                     dummy_component,
-                    img2img_prompt,
-                    img2img_negative_prompt,
-                    img2img_prompt_styles,
+                    toprow.prompt,
+                    toprow.negative_prompt,
+                    toprow.ui_styles.dropdown,
                     init_img,
                     sketch,
                     init_img_with_mask,
@@ -962,11 +923,11 @@ def create_ui():
                     inpaint_color_sketch,
                     init_img_inpaint,
                 ],
-                outputs=[img2img_prompt, dummy_component],
+                outputs=[toprow.prompt, dummy_component],
             )
 
-            img2img_prompt.submit(**img2img_args)
-            submit.click(**img2img_args)
+            toprow.prompt.submit(**img2img_args)
+            toprow.submit.click(**img2img_args)
 
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
 
@@ -978,7 +939,7 @@ def create_ui():
                 show_progress=False,
             )
 
-            restore_progress_button.click(
+            toprow.restore_progress_button.click(
                 fn=progress.restore_progress,
                 _js="restoreProgressImg2img",
                 inputs=[dummy_component],
@@ -991,46 +952,24 @@ def create_ui():
                 show_progress=False,
             )
 
-            img2img_interrogate.click(
+            toprow.button_interrogate.click(
                 fn=lambda *args: process_interrogate(interrogate, *args),
                 **interrogate_args,
             )
 
-            img2img_deepbooru.click(
+            toprow.button_deepbooru.click(
                 fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                 **interrogate_args,
             )
 
-            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
-            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
-            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
-            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
-                button.click(
-                    fn=add_style,
-                    _js="ask_for_style_name",
-                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
-                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
-                    inputs=[dummy_component, prompt, negative_prompt],
-                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],
-                )
-
-            for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
-                button.click(
-                    fn=apply_styles,
-                    _js=js_func,
-                    inputs=[prompt, negative_prompt, styles],
-                    outputs=[prompt, negative_prompt, styles],
-                )
-
-            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
-            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+            toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+            toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
             ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
 
             img2img_paste_fields = [
-                (img2img_prompt, "Prompt"),
-                (img2img_negative_prompt, "Negative prompt"),
+                (toprow.prompt, "Prompt"),
+                (toprow.negative_prompt, "Negative prompt"),
                 (steps, "Steps"),
                 (sampler_index, "Sampler"),
                 (restore_faces, "Face restoration"),
@@ -1040,28 +979,29 @@ def create_ui():
                 (width, "Size-1"),
                 (height, "Size-2"),
                 (batch_size, "Batch size"),
+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
                 (subseed, "Variation seed"),
                 (subseed_strength, "Variation seed strength"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
-                (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                 (denoising_strength, "Denoising strength"),
                 (mask_blur, "Mask blur"),
-                *modules.scripts.scripts_img2img.infotext_fields
+                *scripts.scripts_img2img.infotext_fields
             ]
             parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
             parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+                paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
             ))
 
-    modules.scripts.scripts_current = None
+    scripts.scripts_current = None
 
     with gr.Blocks(analytics_enabled=False) as extras_interface:
         ui_postprocessing.create_ui()
 
     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             with gr.Column(variant='panel'):
                 image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
 
@@ -1086,10 +1026,10 @@ def create_ui():
     modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
 
     with gr.Blocks(analytics_enabled=False) as train_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
 
-        with gr.Row(variant="compact").style(equal_height=False):
+        with gr.Row(variant="compact", equal_height=False):
             with gr.Tabs(elem_id="train_tabs"):
 
                 with gr.Tab(label="Create embedding", id="create_embedding"):
@@ -1109,7 +1049,7 @@ def create_ui():
                     new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                     new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                     new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
-                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
                     new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                     new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                     new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
@@ -1249,12 +1189,12 @@ def create_ui():
 
             with gr.Column(elem_id='ti_gallery_container'):
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-                gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+                gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
                 gr.HTML(elem_id="ti_progress", value="")
                 ti_outcome = gr.HTML(elem_id="ti_error", value="")
 
         create_embedding.click(
-            fn=modules.textual_inversion.ui.create_embedding,
+            fn=textual_inversion_ui.create_embedding,
             inputs=[
                 new_embedding_name,
                 initialization_text,
@@ -1269,7 +1209,7 @@ def create_ui():
         )
 
         create_hypernetwork.click(
-            fn=modules.hypernetworks.ui.create_hypernetwork,
+            fn=hypernetworks_ui.create_hypernetwork,
             inputs=[
                 new_hypernetwork_name,
                 new_hypernetwork_sizes,
@@ -1289,7 +1229,7 @@ def create_ui():
         )
 
         run_preprocess.click(
-            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+            fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             inputs=[
                 dummy_component,
@@ -1325,7 +1265,7 @@ def create_ui():
         )
 
         train_embedding.click(
-            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+            fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             inputs=[
                 dummy_component,
@@ -1359,7 +1299,7 @@ def create_ui():
         )
 
         train_hypernetwork.click(
-            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+            fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             inputs=[
                 dummy_component,

+ 1 - 1
modules/ui_checkpoint_merger.py

@@ -29,7 +29,7 @@ def modelmerger(*args):
 class UiCheckpointMerger:
     def __init__(self):
         with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-            with gr.Row().style(equal_height=False):
+            with gr.Row(equal_height=False):
                 with gr.Column(variant='compact'):
                     self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
 

+ 29 - 5
modules/ui_common.py

@@ -134,7 +134,7 @@ Requested path was: {f}
 
     with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
         with gr.Group(elem_id=f"{tabname}_gallery_container"):
-            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
+            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
 
         generation_info = None
         with gr.Column():
@@ -223,20 +223,44 @@ Requested path was: {f}
 
 
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+    label = None
+    for comp in refresh_components:
+        label = getattr(comp, 'label', None)
+        if label is not None:
+            break
+
     def refresh():
         refresh_method()
         args = refreshed_args() if callable(refreshed_args) else refreshed_args
 
         for k, v in args.items():
-            setattr(refresh_component, k, v)
+            for comp in refresh_components:
+                setattr(comp, k, v)
 
-        return gr.update(**(args or {}))
+        return (gr.update(**(args or {})) for _ in refresh_components) if len(refresh_components) > 1 else gr.update(**(args or {}))
 
-    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
     refresh_button.click(
         fn=refresh,
         inputs=[],
-        outputs=[refresh_component]
+        outputs=refresh_components
     )
     return refresh_button
 
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+    """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+    dialog.visible = False
+
+    button_show.click(
+        fn=lambda: gr.update(visible=True),
+        inputs=[],
+        outputs=[dialog],
+    ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+    if button_close:
+        button_close.click(fn=None, _js="closePopup")
+

+ 1 - 1
modules/ui_components.py

@@ -35,7 +35,7 @@ class FormColumn(FormComponent, gr.Column):
 
 
 class FormGroup(FormComponent, gr.Group):
-    """Same as gr.Row but fits inside gradio forms"""
+    """Same as gr.Group but fits inside gradio forms"""
 
     def get_block_name(self):
         return "group"

+ 15 - 11
modules/ui_extensions.py

@@ -164,7 +164,7 @@ def extension_table():
             ext_status = ext.status
 
         style = ""
-        if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
+        if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
             style = STYLE_PRIMARY
 
         version_link = ext.version
@@ -533,16 +533,20 @@ def create_ui():
                     apply = gr.Button(value=apply_label, variant="primary")
                     check = gr.Button(value="Check for updates")
                     extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
-                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
-                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
+                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
 
                 html = ""
-                if shared.opts.disable_all_extensions != "none":
-                    html = """
-<span style="color: var(--primary-400);">
-    "Disable all extensions" was set, change it to "none" to load all extensions again
-</span>
-                    """
+
+                if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
+                    if shared.cmd_opts.disable_all_extensions:
+                        msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
+                    elif shared.opts.disable_all_extensions != "none":
+                        msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
+                    elif shared.cmd_opts.disable_extra_extensions:
+                        msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
+                    html = f'<span style="color: var(--primary-400);">{msg}</span>'
+
                 info = gr.HTML(html)
                 extensions_table = gr.HTML('Loading...')
                 ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
@@ -565,7 +569,7 @@ def create_ui():
                 with gr.Row():
                     refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
                     extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
-                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
+                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
                     extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                     install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
 
@@ -574,7 +578,7 @@ def create_ui():
                     sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
 
                 with gr.Row():
-                    search_extensions_text = gr.Text(label="Search").style(container=False)
+                    search_extensions_text = gr.Text(label="Search", container=False)
 
                 install_result = gr.HTML()
                 available_extensions_table = gr.HTML()

+ 2 - 11
modules/ui_extra_networks.py

@@ -2,7 +2,7 @@ import os.path
 import urllib.parse
 from pathlib import Path
 
-from modules import shared, ui_extra_networks_user_metadata, errors
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
 from modules.images import read_info_from_image, save_image_with_geninfo
 from modules.ui import up_down_symbol
 import gradio as gr
@@ -101,16 +101,7 @@ class ExtraNetworksPage:
 
     def read_user_metadata(self, item):
         filename = item.get("filename", None)
-        basename, ext = os.path.splitext(filename)
-        metadata_filename = basename + '.json'
-
-        metadata = {}
-        try:
-            if os.path.isfile(metadata_filename):
-                with open(metadata_filename, "r", encoding="utf8") as file:
-                    metadata = json.load(file)
-        except Exception as e:
-            errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+        metadata = extra_networks.get_user_metadata(filename)
 
         desc = metadata.get("description", None)
         if desc is not None:

+ 4 - 1
modules/ui_extra_networks_checkpoints.py

@@ -3,6 +3,7 @@ import os
 
 from modules import shared, ui_extra_networks, sd_models
 from modules.ui_extra_networks import quote_js
+from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
 
 
 class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         shared.refresh_checkpoints()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
         path, ext = os.path.splitext(checkpoint.filename)
         return {
@@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
     def allowed_directories_for_previews(self):
         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
 
+    def create_user_metadata_editor(self, ui, tabname):
+        return CheckpointUserMetadataEditor(ui, tabname, self)

+ 60 - 0
modules/ui_extra_networks_checkpoints_user_metadata.py

@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import ui_extra_networks_user_metadata, sd_vae
+from modules.ui_common import create_refresh_button
+
+
+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+    def __init__(self, ui, tabname, page):
+        super().__init__(ui, tabname, page)
+
+        self.select_vae = None
+
+    def save_user_metadata(self, name, desc, notes, vae):
+        user_metadata = self.get_user_metadata(name)
+        user_metadata["description"] = desc
+        user_metadata["notes"] = notes
+        user_metadata["vae"] = vae
+
+        self.write_user_metadata(name, user_metadata)
+
+    def put_values_into_components(self, name):
+        user_metadata = self.get_user_metadata(name)
+        values = super().put_values_into_components(name)
+
+        return [
+            *values[0:5],
+            user_metadata.get('vae', ''),
+        ]
+
+    def create_editor(self):
+        self.create_default_editor_elems()
+
+        with gr.Row():
+            self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
+            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
+
+        self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+        self.create_default_buttons()
+
+        viewed_components = [
+            self.edit_name,
+            self.edit_description,
+            self.html_filedata,
+            self.html_preview,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.button_edit\
+            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+        edited_components = [
+            self.edit_description,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)

+ 1 - 1
modules/ui_extra_networks_hypernets.py

@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         shared.reload_hypernetworks()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
 

+ 1 - 1
modules/ui_extra_networks_textual_inversion.py

@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
 
         path, ext = os.path.splitext(embedding.filename)

+ 1 - 1
modules/ui_postprocessing.py

@@ -6,7 +6,7 @@ import modules.generation_parameters_copypaste as parameters_copypaste
 def create_ui():
     tab_index = gr.State(value=0)
 
-    with gr.Row().style(equal_height=False, variant='compact'):
+    with gr.Row(equal_height=False, variant='compact'):
         with gr.Column(variant='compact'):
             with gr.Tabs(elem_id="mode_extras"):
                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:

+ 110 - 0
modules/ui_prompt_styles.py

@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️
+styles_materialize_symbol = '\U0001f4cb'  # 📋
+
+
+def select_style(name):
+    style = shared.prompt_styles.styles.get(name)
+    existing = style is not None
+    empty = not name
+
+    prompt = style.prompt if style else gr.update()
+    negative_prompt = style.negative_prompt if style else gr.update()
+
+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+    if not name:
+        return gr.update(visible=False)
+
+    style = styles.PromptStyle(name, prompt, negative_prompt)
+    shared.prompt_styles.styles[style.name] = style
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return gr.update(visible=True)
+
+
+def delete_style(name):
+    if name == "":
+        return
+
+    shared.prompt_styles.styles.pop(name, None)
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+        self.tabname = tabname
+
+        with gr.Row(elem_id=f"{tabname}_styles_row"):
+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+            with gr.Row():
+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+
+            with gr.Row():
+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+            with gr.Row():
+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+            with gr.Row():
+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+        self.selection.change(
+            fn=select_style,
+            inputs=[self.selection],
+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+            show_progress=False,
+        )
+
+        self.save.click(
+            fn=save_style,
+            inputs=[self.selection, self.prompt, self.neg_prompt],
+            outputs=[self.delete],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.delete.click(
+            fn=delete_style,
+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+            inputs=[self.selection],
+            outputs=[self.selection, self.prompt, self.neg_prompt],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.materialize.click(
+            fn=materialize_styles,
+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            show_progress=False,
+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+
+
+
+

+ 1 - 1
modules/ui_settings.py

@@ -158,7 +158,7 @@ class UiSettings:
                     loadsave.create_ui()
 
                 with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
-                    gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+                    gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
 
                     with gr.Row():
                         with gr.Column(scale=1):

+ 2 - 2
requirements.txt

@@ -7,7 +7,7 @@ blendmodes
 clean-fid
 einops
 gfpgan
-gradio==3.32.0
+gradio==3.39.0
 inflection
 jsonmerge
 kornia
@@ -30,4 +30,4 @@ tomesd
 torch
 torchdiffeq
 torchsde
-transformers==4.25.1
+transformers==4.30.2

+ 7 - 7
requirements_versions.txt

@@ -1,13 +1,13 @@
-GitPython==3.1.30
+GitPython==3.1.32
 Pillow==9.5.0
-accelerate==0.18.0
+accelerate==0.21.0
 basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
 einops==0.4.1
 fastapi==0.94.0
 gfpgan==1.3.8
-gradio==3.32.0
+gradio==3.39.0
 httpcore==0.15
 inflection==0.5.1
 jsonmerge==1.8.0
@@ -22,10 +22,10 @@ pytorch_lightning==1.9.4
 realesrgan==0.3.0
 resize-right==0.0.2
 safetensors==0.3.1
-scikit-image==0.20.0
-timm==0.6.7
-tomesd==0.1.2
+scikit-image==0.21.0
+timm==0.9.2
+tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
 torchsde==0.2.5
-transformers==4.25.1
+transformers==4.30.2

+ 3 - 10
scripts/xyz_grid.py

@@ -67,14 +67,6 @@ def apply_order(p, x, xs):
     p.prompt = prompt_tmp + p.prompt
 
 
-def apply_sampler(p, x, xs):
-    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
-    if sampler_name is None:
-        raise RuntimeError(f"Unknown sampler: {x}")
-
-    p.sampler_name = sampler_name
-
-
 def confirm_samplers(p, xs):
     for x in xs:
         if x.lower() not in sd_samplers.samplers_map:
@@ -224,8 +216,9 @@ axis_options = [
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
-    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
-    AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
     AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),

+ 27 - 2
style.css

@@ -8,6 +8,7 @@
     --checkbox-label-gap: 0.25em 0.1em;
     --section-header-text-size: 12pt;
     --block-background-fill: transparent;
+
 }
 
 .block.padded:not(.gradio-accordion) {
@@ -42,7 +43,8 @@ div.form{
 .block.gradio-radio,
 .block.gradio-checkboxgroup,
 .block.gradio-number,
-.block.gradio-colorpicker
+.block.gradio-colorpicker,
+div.gradio-group
 {
     border-width: 0 !important;
     box-shadow: none !important;
@@ -133,6 +135,15 @@ a{
     cursor: pointer;
 }
 
+div.styler{
+    border: none;
+    background: var(--background-fill-primary);
+}
+
+.block.gradio-textbox{
+    overflow: visible !important;
+}
+
 
 /* general styled components */
 
@@ -164,7 +175,7 @@ a{
 .checkboxes-row > div{
     flex: 0;
     white-space: nowrap;
-    min-width: auto;
+    min-width: auto !important;
 }
 
 button.custom-button{
@@ -388,6 +399,7 @@ div#extras_scale_to_tab div.form{
 #quicksettings > div, #quicksettings > fieldset{
     max-width: 24em;
     min-width: 24em;
+    width: 24em;
     padding: 0;
     border: none;
     box-shadow: none;
@@ -972,3 +984,16 @@ div.block.gradio-box.edit-user-metadata {
 .edit-user-metadata-buttons{
     margin-top: 1.5em;
 }
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+    width: 56em;
+    background: var(--body-background-fill);
+    padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+    margin-top: 1em;
+}

+ 8 - 35
webui.py

@@ -14,7 +14,6 @@ from typing import Iterable
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
 
 import logging
 
@@ -50,6 +49,7 @@ startup_timer.record("setup paths")
 import ldm.modules.encoders.modules  # noqa: F401
 startup_timer.record("import ldm")
 
+
 from modules import extra_networks
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401
 
@@ -58,10 +58,15 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
     torch.__long_version__ = torch.__version__
     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
 
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared
+
+if not shared.cmd_opts.skip_version_check:
+    errors.check_versions()
+
 import modules.codeformer_model as codeformer
-import modules.face_restoration
 import modules.gfpgan_model as gfpgan
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
 import modules.img2img
 
 import modules.lowvram
@@ -130,37 +135,6 @@ def fix_asyncio_event_loop_policy():
     asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
 
 
-def check_versions():
-    if shared.cmd_opts.skip_version_check:
-        return
-
-    expected_torch_version = "2.0.0"
-
-    if version.parse(torch.__version__) < version.parse(expected_torch_version):
-        errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, as well as
-there are reports of issues with training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
-        """.strip())
-
-    expected_xformers_version = "0.0.20"
-    if shared.xformers_available:
-        import xformers
-
-        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
-            errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
-            """.strip())
-
-
 def restore_config_state_file():
     config_state_file = shared.opts.restore_config_state_file
     if config_state_file == "":
@@ -248,7 +222,6 @@ def initialize():
     fix_asyncio_event_loop_policy()
     validate_tls_options()
     configure_sigint_handler()
-    check_versions()
     modelloader.cleanup_models()
     configure_opts_onchange()