Эх сурвалжийг харах

Merge branch 'dev' into auro-autolaunch

w-e-w 2 жил өмнө
parent
commit
f17c8c2eff
58 өөрчлөгдсөн 1575 нэмэгдсэн , 677 устгасан
  1. 4 0
      .eslintrc.js
  2. 2 2
      README.md
  3. 1 1
      extensions-builtin/Lora/ui_edit_user_metadata.py
  4. 34 10
      extensions-builtin/extra-options-section/scripts/extra_options_section.py
  5. 33 6
      javascript/extraNetworks.js
  6. 5 0
      javascript/imageviewer.js
  7. 26 0
      javascript/localStorage.js
  8. 5 5
      javascript/localization.js
  9. 8 16
      javascript/ui.js
  10. 2 0
      modules/cmd_args.py
  11. 75 8
      modules/devices.py
  12. 50 0
      modules/errors.py
  13. 7 3
      modules/extensions.py
  14. 19 0
      modules/extra_networks.py
  15. 20 0
      modules/generation_parameters_copypaste.py
  16. 60 0
      modules/gradio_extensons.py
  17. 2 3
      modules/hypernetworks/hypernetwork.py
  18. 1 1
      modules/images.py
  19. 2 4
      modules/img2img.py
  20. 26 3
      modules/launch_utils.py
  21. 3 0
      modules/lowvram.py
  22. 116 72
      modules/processing.py
  23. 7 2
      modules/prompt_parser.py
  24. 102 0
      modules/rng_philox.py
  25. 0 60
      modules/scripts.py
  26. 10 6
      modules/sd_hijack.py
  27. 2 0
      modules/sd_hijack_clip.py
  28. 0 2
      modules/sd_hijack_inpainting.py
  29. 2 2
      modules/sd_hijack_optimizations.py
  30. 131 37
      modules/sd_models.py
  31. 4 4
      modules/sd_models_xl.py
  32. 45 11
      modules/sd_samplers_common.py
  33. 38 5
      modules/sd_samplers_kdiffusion.py
  34. 66 13
      modules/sd_vae.py
  35. 1 1
      modules/sd_vae_approx.py
  36. 44 8
      modules/sd_vae_taesd.py
  37. 129 45
      modules/shared.py
  38. 1 4
      modules/styles.py
  39. 3 1
      modules/textual_inversion/textual_inversion.py
  40. 2 1
      modules/txt2img.py
  41. 160 213
      modules/ui.py
  42. 1 1
      modules/ui_checkpoint_merger.py
  43. 29 5
      modules/ui_common.py
  44. 1 1
      modules/ui_components.py
  45. 15 11
      modules/ui_extensions.py
  46. 30 44
      modules/ui_extra_networks.py
  47. 4 1
      modules/ui_extra_networks_checkpoints.py
  48. 66 0
      modules/ui_extra_networks_checkpoints_user_metadata.py
  49. 1 1
      modules/ui_extra_networks_hypernets.py
  50. 1 1
      modules/ui_extra_networks_textual_inversion.py
  51. 1 1
      modules/ui_postprocessing.py
  52. 110 0
      modules/ui_prompt_styles.py
  53. 1 1
      modules/ui_settings.py
  54. 2 2
      requirements.txt
  55. 7 7
      requirements_versions.txt
  56. 3 10
      scripts/xyz_grid.py
  57. 45 6
      style.css
  58. 10 36
      webui.py

+ 4 - 0
.eslintrc.js

@@ -87,5 +87,9 @@ module.exports = {
         modalNextImage: "readonly",
         modalNextImage: "readonly",
         // token-counters.js
         // token-counters.js
         setupTokenCounters: "readonly",
         setupTokenCounters: "readonly",
+        // localStorage.js
+        localSet: "readonly",
+        localGet: "readonly",
+        localRemove: "readonly"
     }
     }
 };
 };

+ 2 - 2
README.md

@@ -115,7 +115,7 @@ Alternatively, use online services (like Google Colab):
 1. Install the dependencies:
 1. Install the dependencies:
 ```bash
 ```bash
 # Debian-based:
 # Debian-based:
-sudo apt install wget git python3 python3-venv
+sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
 # Red Hat-based:
 # Red Hat-based:
 sudo dnf install wget git python3
 sudo dnf install wget git python3
 # Arch-based:
 # Arch-based:
@@ -123,7 +123,7 @@ sudo pacman -S wget git python3
 ```
 ```
 2. Navigate to the directory you would like the webui to be installed and execute the following command:
 2. Navigate to the directory you would like the webui to be installed and execute the following command:
 ```bash
 ```bash
-bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
+wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
 ```
 ```
 3. Run `webui.sh`.
 3. Run `webui.sh`.
 4. Check `webui-user.sh` for options.
 4. Check `webui-user.sh` for options.

+ 1 - 1
extensions-builtin/Lora/ui_edit_user_metadata.py

@@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
 
 
             with gr.Column(scale=1, min_width=120):
             with gr.Column(scale=1, min_width=120):
-                generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+                generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
 
 
         self.edit_notes = gr.TextArea(label='Notes', lines=4)
         self.edit_notes = gr.TextArea(label='Notes', lines=4)
 
 

+ 34 - 10
extensions-builtin/extra-options-section/scripts/extra_options_section.py

@@ -1,5 +1,7 @@
+import math
+
 import gradio as gr
 import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings
+from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
 from modules.ui_components import FormColumn
 from modules.ui_components import FormColumn
 
 
 
 
@@ -19,18 +21,37 @@ class ExtraOptionsSection(scripts.Script):
     def ui(self, is_img2img):
     def ui(self, is_img2img):
         self.comps = []
         self.comps = []
         self.setting_names = []
         self.setting_names = []
+        self.infotext_fields = []
+
+        mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
 
 
         with gr.Blocks() as interface:
         with gr.Blocks() as interface:
-            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
-                for setting_name in shared.opts.extra_options:
-                    with FormColumn():
-                        comp = ui_settings.create_setting_component(setting_name)
+            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():
+
+                row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)
+
+                for row in range(row_count):
+                    with gr.Row():
+                        for col in range(shared.opts.extra_options_cols):
+                            index = row * shared.opts.extra_options_cols + col
+                            if index >= len(shared.opts.extra_options):
+                                break
+
+                            setting_name = shared.opts.extra_options[index]
 
 
-                    self.comps.append(comp)
-                    self.setting_names.append(setting_name)
+                            with FormColumn():
+                                comp = ui_settings.create_setting_component(setting_name)
+
+                            self.comps.append(comp)
+                            self.setting_names.append(setting_name)
+
+                            setting_infotext_name = mapping.get(setting_name)
+                            if setting_infotext_name is not None:
+                                self.infotext_fields.append((comp, setting_infotext_name))
 
 
         def get_settings_values():
         def get_settings_values():
-            return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
+            res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
+            return res[0] if len(res) == 1 else res
 
 
         interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
         interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
 
 
@@ -43,6 +64,9 @@ class ExtraOptionsSection(scripts.Script):
 
 
 
 
 shared.options_templates.update(shared.options_section(('ui', "User interface"), {
 shared.options_templates.update(shared.options_section(('ui', "User interface"), {
-    "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
-    "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
+    "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
+    "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
+    "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
 }))
 }))
+
+

+ 33 - 6
javascript/extraNetworks.js

@@ -1,20 +1,38 @@
+function toggleCss(key, css, enable) {
+    var style = document.getElementById(key);
+    if (enable && !style) {
+        style = document.createElement('style');
+        style.id = key;
+        style.type = 'text/css';
+        document.head.appendChild(style);
+    }
+    if (style && !enable) {
+        document.head.removeChild(style);
+    }
+    if (style) {
+        style.innerHTML == '';
+        style.appendChild(document.createTextNode(css));
+    }
+}
+
 function setupExtraNetworksForTab(tabname) {
 function setupExtraNetworksForTab(tabname) {
     gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
     gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
 
 
     var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
     var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
-    var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
+    var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
+    var search = searchDiv.querySelector('textarea');
     var sort = gradioApp().getElementById(tabname + '_extra_sort');
     var sort = gradioApp().getElementById(tabname + '_extra_sort');
     var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
     var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
     var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
     var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
+    var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
+    var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
 
 
-    search.classList.add('search');
-    sort.classList.add('sort');
-    sortOrder.classList.add('sortorder');
     sort.dataset.sortkey = 'sortDefault';
     sort.dataset.sortkey = 'sortDefault';
-    tabs.appendChild(search);
+    tabs.appendChild(searchDiv);
     tabs.appendChild(sort);
     tabs.appendChild(sort);
     tabs.appendChild(sortOrder);
     tabs.appendChild(sortOrder);
     tabs.appendChild(refresh);
     tabs.appendChild(refresh);
+    tabs.appendChild(showDirsDiv);
 
 
     var applyFilter = function() {
     var applyFilter = function() {
         var searchTerm = search.value.toLowerCase();
         var searchTerm = search.value.toLowerCase();
@@ -80,6 +98,15 @@ function setupExtraNetworksForTab(tabname) {
     });
     });
 
 
     extraNetworksApplyFilter[tabname] = applyFilter;
     extraNetworksApplyFilter[tabname] = applyFilter;
+
+    var showDirsUpdate = function() {
+        var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }';
+        toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked);
+        localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0);
+    };
+    showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1;
+    showDirs.addEventListener("change", showDirsUpdate);
+    showDirsUpdate();
 }
 }
 
 
 function applyExtraNetworkFilter(tabname) {
 function applyExtraNetworkFilter(tabname) {
@@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) {
 }
 }
 
 
 function extraNetworksSearchButton(tabs_id, event) {
 function extraNetworksSearchButton(tabs_id, event) {
-    var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
+    var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
     var button = event.target;
     var button = event.target;
     var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
     var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
 
 

+ 5 - 0
javascript/imageviewer.js

@@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
     var event = isFirefox ? 'mousedown' : 'click';
     var event = isFirefox ? 'mousedown' : 'click';
 
 
     e.addEventListener(event, function(evt) {
     e.addEventListener(event, function(evt) {
+        if (evt.button == 1) {
+            open(evt.target.src);
+            evt.preventDefault();
+            return;
+        }
         if (!opts.js_modal_lightbox || evt.button != 0) return;
         if (!opts.js_modal_lightbox || evt.button != 0) return;
 
 
         modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
         modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);

+ 26 - 0
javascript/localStorage.js

@@ -0,0 +1,26 @@
+
+function localSet(k, v) {
+    try {
+        localStorage.setItem(k, v);
+    } catch (e) {
+        console.warn(`Failed to save ${k} to localStorage: ${e}`);
+    }
+}
+
+function localGet(k, def) {
+    try {
+        return localStorage.getItem(k);
+    } catch (e) {
+        console.warn(`Failed to load ${k} from localStorage: ${e}`);
+    }
+
+    return def;
+}
+
+function localRemove(k) {
+    try {
+        return localStorage.removeItem(k);
+    } catch (e) {
+        console.warn(`Failed to remove ${k} from localStorage: ${e}`);
+    }
+}

+ 5 - 5
javascript/localization.js

@@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
     train_hypernetwork: 'OPTION',
     train_hypernetwork: 'OPTION',
     txt2img_styles: 'OPTION',
     txt2img_styles: 'OPTION',
     img2img_styles: 'OPTION',
     img2img_styles: 'OPTION',
-    setting_random_artist_categories: 'SPAN',
-    setting_face_restoration_model: 'SPAN',
-    setting_realesrgan_enabled_models: 'SPAN',
-    extras_upscaler_1: 'SPAN',
-    extras_upscaler_2: 'SPAN',
+    setting_random_artist_categories: 'OPTION',
+    setting_face_restoration_model: 'OPTION',
+    setting_realesrgan_enabled_models: 'OPTION',
+    extras_upscaler_1: 'OPTION',
+    extras_upscaler_2: 'OPTION',
 };
 };
 
 
 var re_num = /^[.\d]+$/;
 var re_num = /^[.\d]+$/;

+ 8 - 16
javascript/ui.js

@@ -152,15 +152,11 @@ function submit() {
     showSubmitButtons('txt2img', false);
     showSubmitButtons('txt2img', false);
 
 
     var id = randomId();
     var id = randomId();
-    try {
-        localStorage.setItem("txt2img_task_id", id);
-    } catch (e) {
-        console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
-    }
+    localSet("txt2img_task_id", id);
 
 
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
         showSubmitButtons('txt2img', true);
         showSubmitButtons('txt2img', true);
-        localStorage.removeItem("txt2img_task_id");
+        localRemove("txt2img_task_id");
         showRestoreProgressButton('txt2img', false);
         showRestoreProgressButton('txt2img', false);
     });
     });
 
 
@@ -175,15 +171,11 @@ function submit_img2img() {
     showSubmitButtons('img2img', false);
     showSubmitButtons('img2img', false);
 
 
     var id = randomId();
     var id = randomId();
-    try {
-        localStorage.setItem("img2img_task_id", id);
-    } catch (e) {
-        console.warn(`Failed to save img2img task id to localStorage: ${e}`);
-    }
+    localSet("img2img_task_id", id);
 
 
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
         showSubmitButtons('img2img', true);
         showSubmitButtons('img2img', true);
-        localStorage.removeItem("img2img_task_id");
+        localRemove("img2img_task_id");
         showRestoreProgressButton('img2img', false);
         showRestoreProgressButton('img2img', false);
     });
     });
 
 
@@ -197,7 +189,7 @@ function submit_img2img() {
 
 
 function restoreProgressTxt2img() {
 function restoreProgressTxt2img() {
     showRestoreProgressButton("txt2img", false);
     showRestoreProgressButton("txt2img", false);
-    var id = localStorage.getItem("txt2img_task_id");
+    var id = localGet("txt2img_task_id");
 
 
     if (id) {
     if (id) {
         requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
         requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
@@ -211,7 +203,7 @@ function restoreProgressTxt2img() {
 function restoreProgressImg2img() {
 function restoreProgressImg2img() {
     showRestoreProgressButton("img2img", false);
     showRestoreProgressButton("img2img", false);
 
 
-    var id = localStorage.getItem("img2img_task_id");
+    var id = localGet("img2img_task_id");
 
 
     if (id) {
     if (id) {
         requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
         requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
@@ -224,8 +216,8 @@ function restoreProgressImg2img() {
 
 
 
 
 onUiLoaded(function() {
 onUiLoaded(function() {
-    showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"));
-    showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"));
+    showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
+    showRestoreProgressButton('img2img', localGet("img2img_task_id"));
 });
 });
 
 
 
 

+ 2 - 0
modules/cmd_args.py

@@ -112,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
 parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)

+ 75 - 8
modules/devices.py

@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache
 from functools import lru_cache
 
 
 import torch
 import torch
-from modules import errors
+from modules import errors, rng_philox
 
 
 if sys.platform == "darwin":
 if sys.platform == "darwin":
     from modules import mac_specific
     from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
         torch.backends.cudnn.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
 
 
 
 
-
 errors.run(enable_tf32, "Enabling TF32")
 errors.run(enable_tf32, "Enabling TF32")
 
 
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
 unet_needs_upcast = False
 unet_needs_upcast = False
 
 
 
 
@@ -90,23 +93,87 @@ def cond_cast_float(input):
     return input.float() if unet_needs_upcast else input
     return input.float() if unet_needs_upcast else input
 
 
 
 
+nv_rng = None
+
+
 def randn(seed, shape):
 def randn(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
     from modules.shared import opts
     from modules.shared import opts
 
 
-    torch.manual_seed(seed)
+    manual_seed(seed)
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
         return torch.randn(shape, device=cpu).to(device)
+
     return torch.randn(shape, device=device)
     return torch.randn(shape, device=device)
 
 
 
 
+def randn_local(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        rng = rng_philox.Generator(seed)
+        return torch.asarray(rng.randn(shape), device=device)
+
+    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+    local_generator = torch.Generator(local_device).manual_seed(int(seed))
+    return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+    if opts.randn_source == "CPU" or x.device.type == 'mps':
+        return torch.randn_like(x, device=cpu).to(x.device)
+
+    return torch.randn_like(x)
+
+
 def randn_without_seed(shape):
 def randn_without_seed(shape):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
     from modules.shared import opts
     from modules.shared import opts
 
 
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
         return torch.randn(shape, device=cpu).to(device)
+
     return torch.randn(shape, device=device)
     return torch.randn(shape, device=device)
 
 
 
 
+def manual_seed(seed):
+    """Set up a global random number generator using the specified seed."""
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        global nv_rng
+        nv_rng = rng_philox.Generator(seed)
+        return
+
+    torch.manual_seed(seed)
+
+
 def autocast(disable=False):
 def autocast(disable=False):
     from modules import shared
     from modules import shared
 
 

+ 50 - 0
modules/errors.py

@@ -84,3 +84,53 @@ def run(code, task):
         code()
         code()
     except Exception as e:
     except Exception as e:
         display(task, e)
         display(task, e)
+
+
+def check_versions():
+    from packaging import version
+    from modules import shared
+
+    import torch
+    import gradio
+
+    expected_torch_version = "2.0.0"
+    expected_xformers_version = "0.0.20"
+    expected_gradio_version = "3.39.0"
+
+    if version.parse(torch.__version__) < version.parse(expected_torch_version):
+        print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
+
+    if shared.xformers_available:
+        import xformers
+
+        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+            print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+            """.strip())
+
+    if gradio.__version__ != expected_gradio_version:
+        print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Reasons why you have the mismatched gradio version can be:
+  - you use --skip-install flag.
+  - you use webui.py to start the program instead of launch.py.
+  - an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
+

+ 7 - 3
modules/extensions.py

@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
 
 
 
 
 def active():
 def active():
-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
         return []
         return []
-    elif shared.opts.disable_all_extensions == "extra":
+    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
         return [x for x in extensions if x.enabled and x.is_builtin]
         return [x for x in extensions if x.enabled and x.is_builtin]
     else:
     else:
         return [x for x in extensions if x.enabled]
         return [x for x in extensions if x.enabled]
@@ -141,8 +141,12 @@ def list_extensions():
     if not os.path.isdir(extensions_dir):
     if not os.path.isdir(extensions_dir):
         return
         return
 
 
-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions:
+        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+    elif shared.opts.disable_all_extensions == "all":
         print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
         print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+    elif shared.cmd_opts.disable_extra_extensions:
+        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
     elif shared.opts.disable_all_extensions == "extra":
     elif shared.opts.disable_all_extensions == "extra":
         print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
         print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
 
 

+ 19 - 0
modules/extra_networks.py

@@ -1,3 +1,5 @@
+import json
+import os
 import re
 import re
 from collections import defaultdict
 from collections import defaultdict
 
 
@@ -177,3 +179,20 @@ def parse_prompts(prompts):
 
 
     return res, extra_data
     return res, extra_data
 
 
+
+def get_user_metadata(filename):
+    if filename is None:
+        return {}
+
+    basename, ext = os.path.splitext(filename)
+    metadata_filename = basename + '.json'
+
+    metadata = {}
+    try:
+        if os.path.isfile(metadata_filename):
+            with open(metadata_filename, "r", encoding="utf8") as file:
+                metadata = json.load(file)
+    except Exception as e:
+        errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+    return metadata

+ 20 - 0
modules/generation_parameters_copypaste.py

@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires sampler" not in res:
     if "Hires sampler" not in res:
         res["Hires sampler"] = "Use same sampler"
         res["Hires sampler"] = "Use same sampler"
 
 
+    if "Hires checkpoint" not in res:
+        res["Hires checkpoint"] = "Use same checkpoint"
+
     if "Hires prompt" not in res:
     if "Hires prompt" not in res:
         res["Hires prompt"] = ""
         res["Hires prompt"] = ""
 
 
@@ -304,6 +307,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Schedule rho" not in res:
     if "Schedule rho" not in res:
         res["Schedule rho"] = 0
         res["Schedule rho"] = 0
 
 
+    if "VAE Encoder" not in res:
+        res["VAE Encoder"] = "Full"
+
+    if "VAE Decoder" not in res:
+        res["VAE Decoder"] = "Full"
+
     return res
     return res
 
 
 
 
@@ -319,6 +328,10 @@ infotext_to_setting_name_mapping = [
     ('Noise multiplier', 'initial_noise_multiplier'),
     ('Noise multiplier', 'initial_noise_multiplier'),
     ('Eta', 'eta_ancestral'),
     ('Eta', 'eta_ancestral'),
     ('Eta DDIM', 'eta_ddim'),
     ('Eta DDIM', 'eta_ddim'),
+    ('Sigma churn', 's_churn'),
+    ('Sigma tmin', 's_tmin'),
+    ('Sigma tmax', 's_tmax'),
+    ('Sigma noise', 's_noise'),
     ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
     ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
     ('UniPC variant', 'uni_pc_variant'),
     ('UniPC variant', 'uni_pc_variant'),
     ('UniPC skip type', 'uni_pc_skip_type'),
     ('UniPC skip type', 'uni_pc_skip_type'),
@@ -329,6 +342,8 @@ infotext_to_setting_name_mapping = [
     ('RNG', 'randn_source'),
     ('RNG', 'randn_source'),
     ('NGMS', 's_min_uncond'),
     ('NGMS', 's_min_uncond'),
     ('Pad conds', 'pad_cond_uncond'),
     ('Pad conds', 'pad_cond_uncond'),
+    ('VAE Encoder', 'sd_vae_encode_method'),
+    ('VAE Decoder', 'sd_vae_decode_method'),
 ]
 ]
 
 
 
 
@@ -399,10 +414,15 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
         return res
         return res
 
 
     if override_settings_component is not None:
     if override_settings_component is not None:
+        already_handled_fields = {key: 1 for _, key in paste_fields}
+
         def paste_settings(params):
         def paste_settings(params):
             vals = {}
             vals = {}
 
 
             for param_name, setting_name in infotext_to_setting_name_mapping:
             for param_name, setting_name in infotext_to_setting_name_mapping:
+                if param_name in already_handled_fields:
+                    continue
+
                 v = params.get(param_name, None)
                 v = params.get(param_name, None)
                 if v is None:
                 if v is None:
                     continue
                     continue

+ 60 - 0
modules/gradio_extensons.py

@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import scripts
+
+def add_classes_to_gradio_component(comp):
+    """
+    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+    """
+
+    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+
+    if getattr(comp, 'multiselect', False):
+        comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+    self.webui_tooltip = kwargs.pop('tooltip', None)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.before_component(self, **kwargs)
+
+    scripts.script_callbacks.before_component_callback(self, **kwargs)
+
+    res = original_IOComponent_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.after_component(self, **kwargs)
+
+    return res
+
+
+def Block_get_config(self):
+    config = original_Block_get_config(self)
+
+    webui_tooltip = getattr(self, 'webui_tooltip', None)
+    if webui_tooltip:
+        config["webui_tooltip"] = webui_tooltip
+
+    return config
+
+
+def BlockContext_init(self, *args, **kwargs):
+    res = original_BlockContext_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.blocks.Block.get_config
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+
+gr.components.IOComponent.__init__ = IOComponent_init
+gr.blocks.Block.get_config = Block_get_config
+gr.blocks.BlockContext.__init__ = BlockContext_init

+ 2 - 3
modules/hypernetworks/hypernetwork.py

@@ -10,7 +10,7 @@ import torch
 import tqdm
 import tqdm
 from einops import rearrange, repeat
 from einops import rearrange, repeat
 from ldm.util import default
 from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
 from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
 
 
 
 
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
-    from modules import images
+    from modules import images, processing
 
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
     create_image_every = create_image_every or 0

+ 1 - 1
modules/images.py

@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
     return res
     return res
 
 
 
 
-invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n\r\t'
 invalid_filename_prefix = ' '
 invalid_filename_prefix = ' '
 invalid_filename_postfix = ' .'
 invalid_filename_postfix = ' .'
 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')

+ 2 - 4
modules/img2img.py

@@ -3,7 +3,7 @@ from contextlib import closing
 from pathlib import Path
 from pathlib import Path
 
 
 import numpy as np
 import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
 import gradio as gr
 import gradio as gr
 
 
 from modules import sd_samplers, images as imgutil
 from modules import sd_samplers, images as imgutil
@@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         mask = None
         mask = None
     elif mode == 2:  # inpaint
     elif mode == 2:  # inpaint
         image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
         image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
-        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
-        mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
-        mask = ImageChops.lighter(alpha_mask, mask).convert('L')
+        mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
         image = image.convert("RGB")
         image = image.convert("RGB")
     elif mode == 3:  # inpaint sketch
     elif mode == 3:  # inpaint sketch
         image = inpaint_color_sketch
         image = inpaint_color_sketch

+ 26 - 3
modules/launch_utils.py

@@ -139,6 +139,27 @@ def check_run_python(code: str) -> bool:
     return result.returncode == 0
     return result.returncode == 0
 
 
 
 
+def git_fix_workspace(dir, name):
+    run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
+    run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
+    return
+
+
+def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
+    try:
+        return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
+    except RuntimeError:
+        pass
+
+    if not autofix:
+        return None
+
+    print(f"{errdesc}, attempting autofix...")
+    git_fix_workspace(dir, name)
+
+    return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
+
+
 def git_clone(url, dir, name, commithash=None):
 def git_clone(url, dir, name, commithash=None):
     # TODO clone into temporary dir and move if successful
     # TODO clone into temporary dir and move if successful
 
 
@@ -146,12 +167,14 @@ def git_clone(url, dir, name, commithash=None):
         if commithash is None:
         if commithash is None:
             return
             return
 
 
-        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
+        current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
         if current_hash == commithash:
         if current_hash == commithash:
             return
             return
 
 
-        run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+        run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+
+        run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+
         return
         return
 
 
     run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
     run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)

+ 3 - 0
modules/lowvram.py

@@ -15,6 +15,9 @@ def send_everything_to_cpu():
 
 
 
 
 def setup_for_low_vram(sd_model, use_medvram):
 def setup_for_low_vram(sd_model, use_medvram):
+    if getattr(sd_model, 'lowvram', False):
+        return
+
     sd_model.lowvram = True
     sd_model.lowvram = True
 
 
     parents = {}
     parents = {}

+ 116 - 72
modules/processing.py

@@ -16,6 +16,7 @@ from typing import Any, Dict, List
 import modules.sd_hijack
 import modules.sd_hijack
 from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
 from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
 from modules.sd_hijack import model_hijack
 from modules.sd_hijack import model_hijack
+from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
 from modules.shared import opts, cmd_opts, state
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.shared as shared
 import modules.paths as paths
 import modules.paths as paths
@@ -83,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
 
 
         # The "masked-image" in this case will just be all zeros since the entire image is masked.
         # The "masked-image" in this case will just be all zeros since the entire image is masked.
         image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
         image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
-        image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+        image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
 
 
         # Add the fake full 1s mask to the first dimension.
         # Add the fake full 1s mask to the first dimension.
         image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
         image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
@@ -109,7 +110,7 @@ class StableDiffusionProcessing:
     cached_uc = [None, None]
     cached_uc = [None, None]
     cached_c = [None, None]
     cached_c = [None, None]
 
 
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
         if sampler_index is not None:
         if sampler_index is not None:
             print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
             print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
 
 
@@ -147,8 +148,8 @@ class StableDiffusionProcessing:
         self.s_min_uncond = s_min_uncond or opts.s_min_uncond
         self.s_min_uncond = s_min_uncond or opts.s_min_uncond
         self.s_churn = s_churn or opts.s_churn
         self.s_churn = s_churn or opts.s_churn
         self.s_tmin = s_tmin or opts.s_tmin
         self.s_tmin = s_tmin or opts.s_tmin
-        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
-        self.s_noise = s_noise or opts.s_noise
+        self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
+        self.s_noise = s_noise if s_noise is not None else opts.s_noise
         self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
         self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
         self.override_settings_restore_afterwards = override_settings_restore_afterwards
         self.override_settings_restore_afterwards = override_settings_restore_afterwards
         self.is_using_inpainting_conditioning = False
         self.is_using_inpainting_conditioning = False
@@ -202,7 +203,7 @@ class StableDiffusionProcessing:
         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
 
 
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
         conditioning = torch.nn.functional.interpolate(
         conditioning = torch.nn.functional.interpolate(
             self.sd_model.depth_model(midas_in),
             self.sd_model.depth_model(midas_in),
             size=conditioning_image.shape[2:],
             size=conditioning_image.shape[2:],
@@ -215,7 +216,7 @@ class StableDiffusionProcessing:
         return conditioning
         return conditioning
 
 
     def edit_image_conditioning(self, source_image):
     def edit_image_conditioning(self, source_image):
-        conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
+        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
 
 
         return conditioning_image
         return conditioning_image
 
 
@@ -294,7 +295,7 @@ class StableDiffusionProcessing:
         self.sampler = None
         self.sampler = None
         self.c = None
         self.c = None
         self.uc = None
         self.uc = None
-        if not opts.experimental_persistent_cond_cache:
+        if not opts.persistent_cond_cache:
             StableDiffusionProcessing.cached_c = [None, None]
             StableDiffusionProcessing.cached_c = [None, None]
             StableDiffusionProcessing.cached_uc = [None, None]
             StableDiffusionProcessing.cached_uc = [None, None]
 
 
@@ -318,6 +319,21 @@ class StableDiffusionProcessing:
         self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
         self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
         self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
         self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
 
 
+    def cached_params(self, required_prompts, steps, extra_network_data):
+        """Returns parameters that invalidate the cond cache if changed"""
+
+        return (
+            required_prompts,
+            steps,
+            opts.CLIP_stop_at_last_layers,
+            shared.sd_model.sd_checkpoint_info,
+            extra_network_data,
+            opts.sdxl_crop_left,
+            opts.sdxl_crop_top,
+            self.width,
+            self.height,
+        )
+
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
         """
         """
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
         Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -331,17 +347,7 @@ class StableDiffusionProcessing:
         caches is a list with items described above.
         caches is a list with items described above.
         """
         """
 
 
-        cached_params = (
-            required_prompts,
-            steps,
-            opts.CLIP_stop_at_last_layers,
-            shared.sd_model.sd_checkpoint_info,
-            extra_network_data,
-            opts.sdxl_crop_left,
-            opts.sdxl_crop_top,
-            self.width,
-            self.height,
-        )
+        cached_params = self.cached_params(required_prompts, steps, extra_network_data)
 
 
         for cache in caches:
         for cache in caches:
             if cache[0] is not None and cached_params == cache[0]:
             if cache[0] is not None and cached_params == cache[0]:
@@ -367,6 +373,10 @@ class StableDiffusionProcessing:
     def parse_extra_network_prompts(self):
     def parse_extra_network_prompts(self):
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
 
 
+    def save_samples(self) -> bool:
+        """Returns whether generated images need to be written to disk"""
+        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
+
 
 
 class Processed:
 class Processed:
     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -492,7 +502,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
 
 
         subnoise = None
         subnoise = None
-        if subseeds is not None:
+        if subseeds is not None and subseed_strength != 0:
             subseed = 0 if i >= len(subseeds) else subseeds[i]
             subseed = 0 if i >= len(subseeds) else subseeds[i]
 
 
             subnoise = devices.randn(subseed, noise_shape)
             subnoise = devices.randn(subseed, noise_shape)
@@ -524,7 +534,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
             cnt = p.sampler.number_of_needed_noises(p)
             cnt = p.sampler.number_of_needed_noises(p)
 
 
             if eta_noise_seed_delta > 0:
             if eta_noise_seed_delta > 0:
-                torch.manual_seed(seed + eta_noise_seed_delta)
+                devices.manual_seed(seed + eta_noise_seed_delta)
 
 
             for j in range(cnt):
             for j in range(cnt):
                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -538,8 +548,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     return x
     return x
 
 
 
 
+class DecodedSamples(list):
+    already_decoded = True
+
+
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
-    samples = []
+    samples = DecodedSamples()
 
 
     for i in range(batch.shape[0]):
     for i in range(batch.shape[0]):
         sample = decode_first_stage(model, batch[i:i + 1])[0]
         sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -572,12 +586,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
     return samples
     return samples
 
 
 
 
-def decode_first_stage(model, x):
-    x = model.decode_first_stage(x.to(devices.dtype_vae))
-
-    return x
-
-
 def get_fixed_seed(seed):
 def get_fixed_seed(seed):
     if seed is None or seed == '' or seed == -1:
     if seed is None or seed == '' or seed == -1:
         return int(random.randrange(4294967294))
         return int(random.randrange(4294967294))
@@ -636,7 +644,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
         "Init image hash": getattr(p, 'init_img_hash', None),
-        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         **p.extra_generation_params,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
         "Version": program_version() if opts.add_version_to_infotext else None,
@@ -793,7 +801,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
 
-            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+            if getattr(samples_ddim, 'already_decoded', False):
+                x_samples_ddim = samples_ddim
+            else:
+                if opts.sd_vae_decode_method != 'Full':
+                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+
+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
 
@@ -817,6 +832,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             def infotext(index=0, use_main_prompt=False):
             def infotext(index=0, use_main_prompt=False):
                 return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
                 return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
 
 
+            save_samples = p.save_samples()
+
             for i, x_sample in enumerate(x_samples_ddim):
             for i, x_sample in enumerate(x_samples_ddim):
                 p.batch_index = i
                 p.batch_index = i
 
 
@@ -824,7 +841,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 x_sample = x_sample.astype(np.uint8)
                 x_sample = x_sample.astype(np.uint8)
 
 
                 if p.restore_faces:
                 if p.restore_faces:
-                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
+                    if save_samples and opts.save_images_before_face_restoration:
                         images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
                         images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
 
 
                     devices.torch_gc()
                     devices.torch_gc()
@@ -838,16 +855,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     pp = scripts.PostprocessImageArgs(image)
                     pp = scripts.PostprocessImageArgs(image)
                     p.scripts.postprocess_image(p, pp)
                     p.scripts.postprocess_image(p, pp)
                     image = pp.image
                     image = pp.image
-
                 if p.color_corrections is not None and i < len(p.color_corrections):
                 if p.color_corrections is not None and i < len(p.color_corrections):
-                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
+                    if save_samples and opts.save_images_before_color_correction:
                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)
                     image = apply_color_correction(p.color_corrections[i], image)
 
 
                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)
                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)
 
 
-                if opts.samples_save and not p.do_not_save_samples:
+                if save_samples:
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
 
 
                 text = infotext(i)
                 text = infotext(i)
@@ -855,8 +871,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 if opts.enable_pnginfo:
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                     image.info["parameters"] = text
                 output_images.append(image)
                 output_images.append(image)
-
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+                if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                     image_mask = p.mask_for_overlay.convert('RGB')
                     image_mask = p.mask_for_overlay.convert('RGB')
                     image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                     image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
 
 
@@ -892,7 +907,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     grid.info["parameters"] = text
                     grid.info["parameters"] = text
                 output_images.insert(0, grid)
                 output_images.insert(0, grid)
                 index_of_first_image = 1
                 index_of_first_image = 1
-
             if opts.grid_save:
             if opts.grid_save:
                 images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
                 images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
 
@@ -935,7 +949,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     cached_hr_uc = [None, None]
     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
     cached_hr_c = [None, None]
 
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
         super().__init__(**kwargs)
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
         self.denoising_strength = denoising_strength
@@ -946,11 +960,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_resize_y = hr_resize_y
         self.hr_resize_y = hr_resize_y
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_y = hr_resize_y
         self.hr_upscale_to_y = hr_resize_y
+        self.hr_checkpoint_name = hr_checkpoint_name
+        self.hr_checkpoint_info = None
         self.hr_sampler_name = hr_sampler_name
         self.hr_sampler_name = hr_sampler_name
         self.hr_prompt = hr_prompt
         self.hr_prompt = hr_prompt
         self.hr_negative_prompt = hr_negative_prompt
         self.hr_negative_prompt = hr_negative_prompt
         self.all_hr_prompts = None
         self.all_hr_prompts = None
         self.all_hr_negative_prompts = None
         self.all_hr_negative_prompts = None
+        self.latent_scale_mode = None
 
 
         if firstphase_width != 0 or firstphase_height != 0:
         if firstphase_width != 0 or firstphase_height != 0:
             self.hr_upscale_to_x = self.width
             self.hr_upscale_to_x = self.width
@@ -973,6 +990,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
 
     def init(self, all_prompts, all_seeds, all_subseeds):
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
         if self.enable_hr:
+            if self.hr_checkpoint_name:
+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                if self.hr_checkpoint_info is None:
+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
 
 
@@ -982,6 +1007,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
 
 
+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+            if self.enable_hr and self.latent_scale_mode is None:
+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                 self.hr_resize_x = self.width
                 self.hr_resize_x = self.width
                 self.hr_resize_y = self.height
                 self.hr_resize_y = self.height
@@ -1020,14 +1050,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                     self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                     self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                     self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
                     self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
 
 
-            # special case: the user has chosen to do nothing
-            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
-                self.enable_hr = False
-                self.denoising_strength = None
-                self.extra_generation_params.pop("Hires upscale", None)
-                self.extra_generation_params.pop("Hires resize", None)
-                return
-
             if not state.processing_has_refined_job_count:
             if not state.processing_has_refined_job_count:
                 if state.job_count == -1:
                 if state.job_count == -1:
                     state.job_count = self.n_iter
                     state.job_count = self.n_iter
@@ -1045,17 +1067,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
 
-        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
-        if self.enable_hr and latent_scale_mode is None:
-            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
-                raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+        del x
 
 
         if not self.enable_hr:
         if not self.enable_hr:
             return samples
             return samples
 
 
+        if self.latent_scale_mode is None:
+            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+        else:
+            decoded_samples = None
+
+        current = shared.sd_model.sd_checkpoint_info
+        try:
+            if self.hr_checkpoint_info is not None:
+                self.sampler = None
+                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+                devices.torch_gc()
+
+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+        finally:
+            self.sampler = None
+            sd_models.reload_model_weights(info=current)
+            devices.torch_gc()
+
+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         self.is_hr_pass = True
         self.is_hr_pass = True
 
 
         target_width = self.hr_upscale_to_x
         target_width = self.hr_upscale_to_x
@@ -1064,7 +1101,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         def save_intermediate(image, index):
         def save_intermediate(image, index):
             """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
             """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
 
 
-            if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+            if not self.save_samples() or not opts.save_images_before_highres_fix:
                 return
                 return
 
 
             if not isinstance(image, Image.Image):
             if not isinstance(image, Image.Image):
@@ -1073,11 +1110,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
             images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
             images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
 
 
-        if latent_scale_mode is not None:
+        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+            img2img_sampler_name = 'DDIM'
+
+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+        if self.latent_scale_mode is not None:
             for i in range(samples.shape[0]):
             for i in range(samples.shape[0]):
                 save_intermediate(samples, i)
                 save_intermediate(samples, i)
 
 
-            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
 
 
             # Avoid making the inpainting conditioning unless necessary as
             # Avoid making the inpainting conditioning unless necessary as
             # this does need some extra compute to decode / encode the image again.
             # this does need some extra compute to decode / encode the image again.
@@ -1086,7 +1130,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             else:
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
         else:
-            decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
 
             batch_images = []
             batch_images = []
@@ -1103,28 +1146,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 batch_images.append(image)
                 batch_images.append(image)
 
 
             decoded_samples = torch.from_numpy(np.array(batch_images))
             decoded_samples = torch.from_numpy(np.array(batch_images))
-            decoded_samples = decoded_samples.to(shared.device)
-            decoded_samples = 2. * decoded_samples - 1.
+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
 
 
-            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+            if opts.sd_vae_encode_method != 'Full':
+                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
 
 
             image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
             image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
 
 
         shared.state.nextjob()
         shared.state.nextjob()
 
 
-        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
-            img2img_sampler_name = 'DDIM'
-
-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
 
 
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
 
         # GC now before running the next img2img to prevent running out of memory
         # GC now before running the next img2img to prevent running out of memory
-        x = None
         devices.torch_gc()
         devices.torch_gc()
 
 
         if not self.disable_extra_networks:
         if not self.disable_extra_networks:
@@ -1143,15 +1179,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 
 
+        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
         self.is_hr_pass = False
         self.is_hr_pass = False
 
 
-        return samples
+        return decoded_samples
 
 
     def close(self):
     def close(self):
         super().close()
         super().close()
         self.hr_c = None
         self.hr_c = None
         self.hr_uc = None
         self.hr_uc = None
-        if not opts.experimental_persistent_cond_cache:
+        if not opts.persistent_cond_cache:
             StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
             StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
             StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
             StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
 
 
@@ -1184,8 +1222,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if self.hr_c is not None:
         if self.hr_c is not None:
             return
             return
 
 
-        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
-        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
 
 
     def setup_conds(self):
     def setup_conds(self):
         super().setup_conds()
         super().setup_conds()
@@ -1193,7 +1234,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
         self.hr_uc = None
         self.hr_c = None
         self.hr_c = None
 
 
-        if self.enable_hr:
+        if self.enable_hr and self.hr_checkpoint_info is None:
             if shared.opts.hires_fix_use_firstpass_conds:
             if shared.opts.hires_fix_use_firstpass_conds:
                 self.calculate_hr_conds()
                 self.calculate_hr_conds()
 
 
@@ -1344,10 +1385,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
 
         image = torch.from_numpy(batch_images)
         image = torch.from_numpy(batch_images)
-        image = 2. * image - 1.
         image = image.to(shared.device, dtype=devices.dtype_vae)
         image = image.to(shared.device, dtype=devices.dtype_vae)
 
 
-        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+        if opts.sd_vae_encode_method != 'Full':
+            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+        devices.torch_gc()
 
 
         if self.resize_mode == 3:
         if self.resize_mode == 3:
             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

+ 7 - 2
modules/prompt_parser.py

@@ -20,7 +20,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
         | "(" prompt ":" prompt ")"
         | "(" prompt ":" prompt ")"
         | "[" prompt "]"
         | "[" prompt "]"
 scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
 scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
-alternate: "[" prompt ("|" prompt)+ "]"
+alternate: "[" prompt ("|" [prompt])+ "]"
 WHITESPACE: /\s+/
 WHITESPACE: /\s+/
 plain: /([^\\\[\]():|]|\\.)+/
 plain: /([^\\\[\]():|]|\\.)+/
 %import common.SIGNED_NUMBER -> NUMBER
 %import common.SIGNED_NUMBER -> NUMBER
@@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     [[3, '((a][:b:c '], [10, '((a][:b:c d']]
     [[3, '((a][:b:c '], [10, '((a][:b:c d']]
     >>> g("[a|(b:1.1)]")
     >>> g("[a|(b:1.1)]")
     [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
     [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+    >>> g("[fe|]male")
+    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
+    >>> g("[fe|||]male")
+    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
     """
     """
 
 
     def collect_steps(steps, tree):
     def collect_steps(steps, tree):
@@ -78,7 +82,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
                 before, after, _, when, _ = args
                 before, after, _, when, _ = args
                 yield before or () if step <= when else after
                 yield before or () if step <= when else after
             def alternate(self, args):
             def alternate(self, args):
-                yield next(args[(step - 1)%len(args)])
+                args = ["" if not arg else arg for arg in args]
+                yield args[(step - 1) % len(args)]
             def start(self, args):
             def start(self, args):
                 def flatten(x):
                 def flatten(x):
                     if type(x) == str:
                     if type(x) == str:

+ 102 - 0
modules/rng_philox.py

@@ -0,0 +1,102 @@
+"""RNG imitiating torch cuda randn on CPU. You are welcome.
+
+Usage:
+
+```
+g = Generator(seed=0)
+print(g.randn(shape=(3, 4)))
+```
+
+Expected output:
+```
+[[-0.92466259 -0.42534415 -2.6438457   0.14518388]
+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
+ [-1.07454231 -0.36314407 -1.67105067  2.26550497]]
+```
+"""
+
+import numpy as np
+
+philox_m = [0xD2511F53, 0xCD9E8D57]
+philox_w = [0x9E3779B9, 0xBB67AE85]
+
+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
+
+
+def uint32(x):
+    """Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
+    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
+
+
+def philox4_round(counter, key):
+    """A single round of the Philox 4x32 random number generator."""
+
+    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
+    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
+
+    counter[0] = v2[1] ^ counter[1] ^ key[0]
+    counter[1] = v2[0]
+    counter[2] = v1[1] ^ counter[3] ^ key[1]
+    counter[3] = v1[0]
+
+
+def philox4_32(counter, key, rounds=10):
+    """Generates 32-bit random numbers using the Philox 4x32 random number generator.
+
+    Parameters:
+        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
+        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
+        rounds (int): The number of rounds to perform.
+
+    Returns:
+        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
+    """
+
+    for _ in range(rounds - 1):
+        philox4_round(counter, key)
+
+        key[0] = key[0] + philox_w[0]
+        key[1] = key[1] + philox_w[1]
+
+    philox4_round(counter, key)
+    return counter
+
+
+def box_muller(x, y):
+    """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
+    u = x * two_pow32_inv + two_pow32_inv / 2
+    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+
+    s = np.sqrt(-2.0 * np.log(u))
+
+    r1 = s * np.sin(v)
+    return r1.astype(np.float32)
+
+
+class Generator:
+    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
+
+    def __init__(self, seed):
+        self.seed = seed
+        self.offset = 0
+
+    def randn(self, shape):
+        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
+
+        n = 1
+        for x in shape:
+            n *= x
+
+        counter = np.zeros((4, n), dtype=np.uint32)
+        counter[0] = self.offset
+        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
+        self.offset += 1
+
+        key = np.empty(n, dtype=np.uint64)
+        key.fill(self.seed)
+        key = uint32(key)
+
+        g = philox4_32(counter, key)
+
+        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]

+ 0 - 60
modules/scripts.py

@@ -631,63 +631,3 @@ def reload_script_body_only():
 
 
 
 
 reload_scripts = load_scripts  # compatibility alias
 reload_scripts = load_scripts  # compatibility alias
-
-
-def add_classes_to_gradio_component(comp):
-    """
-    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
-    """
-
-    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
-
-    if getattr(comp, 'multiselect', False):
-        comp.elem_classes.append('multiselect')
-
-
-
-def IOComponent_init(self, *args, **kwargs):
-    self.webui_tooltip = kwargs.pop('tooltip', None)
-
-    if scripts_current is not None:
-        scripts_current.before_component(self, **kwargs)
-
-    script_callbacks.before_component_callback(self, **kwargs)
-
-    res = original_IOComponent_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    script_callbacks.after_component_callback(self, **kwargs)
-
-    if scripts_current is not None:
-        scripts_current.after_component(self, **kwargs)
-
-    return res
-
-
-def Block_get_config(self):
-    config = original_Block_get_config(self)
-
-    webui_tooltip = getattr(self, 'webui_tooltip', None)
-    if webui_tooltip:
-        config["webui_tooltip"] = webui_tooltip
-
-    return config
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-original_Block_get_config = gr.components.Block.get_config
-gr.components.IOComponent.__init__ = IOComponent_init
-gr.components.Block.get_config = Block_get_config
-
-
-def BlockContext_init(self, *args, **kwargs):
-    res = original_BlockContext_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init

+ 10 - 6
modules/sd_hijack.py

@@ -2,11 +2,10 @@ import torch
 from torch.nn.functional import silu
 from torch.nn.functional import silu
 from types import MethodType
 from types import MethodType
 
 
-import modules.textual_inversion.textual_inversion
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
 
 
 import ldm.modules.attention
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 import ldm.modules.diffusionmodules.model
@@ -30,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
 ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 
 
 # silence new console spam from SD2
 # silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
+
+sd_hijack_inpainting.do_inpainting_hijack()
 
 
 optimizers = []
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
@@ -164,12 +167,13 @@ class StableDiffusionModelHijack:
     clip = None
     clip = None
     optimization_method = None
     optimization_method = None
 
 
-    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
     def __init__(self):
     def __init__(self):
+        import modules.textual_inversion.textual_inversion
+
         self.extra_generation_params = {}
         self.extra_generation_params = {}
         self.comments = []
         self.comments = []
 
 
+        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
 
 
     def apply_optimizations(self, option=None):
     def apply_optimizations(self, option=None):

+ 2 - 0
modules/sd_hijack_clip.py

@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                 hashes.append(f"{name}: {shorthash}")
                 hashes.append(f"{name}: {shorthash}")
 
 
             if hashes:
             if hashes:
+                if self.hijack.extra_generation_params.get("TI hashes"):
+                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
                 self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
                 self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
 
 
         if getattr(self.wrapped, 'return_pooled', False):
         if getattr(self.wrapped, 'return_pooled', False):

+ 0 - 2
modules/sd_hijack_inpainting.py

@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
 
 
 
 
 def do_inpainting_hijack():
 def do_inpainting_hijack():
-    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
     ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
     ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms

+ 2 - 2
modules/sd_hijack_optimizations.py

@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                                f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
                                f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
 
 
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        slice_size = q.shape[1] // steps
         for i in range(0, q.shape[1], slice_size):
         for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(i + slice_size, q.shape[1])
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
 
 
             s2 = s1.softmax(dim=-1, dtype=q.dtype)
             s2 = s1.softmax(dim=-1, dtype=q.dtype)

+ 131 - 37
modules/sd_models.py

@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 from ldm.util import instantiate_from_config
 
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
-from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 from modules.timer import Timer
 import tomesd
 import tomesd
 
 
@@ -67,8 +66,9 @@ class CheckpointInfo:
         self.shorthash = self.sha256[0:10] if self.sha256 else None
         self.shorthash = self.sha256[0:10] if self.sha256 else None
 
 
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
 
 
-        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
 
     def register(self):
     def register(self):
         checkpoints_list[self.title] = self
         checkpoints_list[self.title] = self
@@ -87,6 +87,7 @@ class CheckpointInfo:
 
 
         checkpoints_list.pop(self.title, None)
         checkpoints_list.pop(self.title, None)
         self.title = f'{self.name} [{self.shorthash}]'
         self.title = f'{self.name} [{self.shorthash}]'
+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
         self.register()
         self.register()
 
 
         return self.shorthash
         return self.shorthash
@@ -107,14 +108,8 @@ def setup_model():
     enable_midas_autodownload()
     enable_midas_autodownload()
 
 
 
 
-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
 
 
 
 
 def list_models():
 def list_models():
@@ -137,11 +132,14 @@ def list_models():
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
 
 
-    for filename in sorted(model_list, key=str.lower):
+    for filename in model_list:
         checkpoint_info = CheckpointInfo(filename)
         checkpoint_info = CheckpointInfo(filename)
         checkpoint_info.register()
         checkpoint_info.register()
 
 
 
 
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
 def get_closet_checkpoint_match(search_string):
 def get_closet_checkpoint_match(search_string):
     checkpoint_info = checkpoint_aliases.get(search_string, None)
     checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
     if checkpoint_info is not None:
@@ -151,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
     if found:
     if found:
         return found[0]
         return found[0]
 
 
+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
     return None
     return None
 
 
 
 
@@ -303,12 +306,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         sd_models_xl.extend_sdxl(model)
         sd_models_xl.extend_sdxl(model)
 
 
     model.load_state_dict(state_dict, strict=False)
     model.load_state_dict(state_dict, strict=False)
-    del state_dict
     timer.record("apply weights to model")
     timer.record("apply weights to model")
 
 
     if shared.opts.sd_checkpoint_cache > 0:
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
         # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+        checkpoints_loaded[checkpoint_info] = state_dict
+
+    del state_dict
 
 
     if shared.cmd_opts.opt_channelslast:
     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)
         model.to(memory_format=torch.channels_last)
@@ -352,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
 
     sd_vae.delete_base_vae()
     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()
     sd_vae.clear_loaded_vae()
-    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
     sd_vae.load_vae(model, vae_file, vae_source)
     sd_vae.load_vae(model, vae_file, vae_source)
     timer.record("load VAE")
     timer.record("load VAE")
 
 
@@ -429,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 class SdModelData:
 class SdModelData:
     def __init__(self):
     def __init__(self):
         self.sd_model = None
         self.sd_model = None
+        self.loaded_sd_models = []
         self.was_loaded_at_least_once = False
         self.was_loaded_at_least_once = False
         self.lock = threading.Lock()
         self.lock = threading.Lock()
 
 
@@ -443,6 +448,7 @@ class SdModelData:
 
 
                 try:
                 try:
                     load_model()
                     load_model()
+
                 except Exception as e:
                 except Exception as e:
                     errors.display(e, "loading stable diffusion model", full_traceback=True)
                     errors.display(e, "loading stable diffusion model", full_traceback=True)
                     print("", file=sys.stderr)
                     print("", file=sys.stderr)
@@ -454,11 +460,24 @@ class SdModelData:
     def set_sd_model(self, v):
     def set_sd_model(self, v):
         self.sd_model = v
         self.sd_model = v
 
 
+        try:
+            self.loaded_sd_models.remove(v)
+        except ValueError:
+            pass
+
+        if v is not None:
+            self.loaded_sd_models.insert(0, v)
+
 
 
 model_data = SdModelData()
 model_data = SdModelData()
 
 
 
 
 def get_empty_cond(sd_model):
 def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
     if hasattr(sd_model, 'conditioner'):
     if hasattr(sd_model, 'conditioner'):
         d = sd_model.get_learned_conditioning([""])
         d = sd_model.get_learned_conditioning([""])
         return d['crossattn']
         return d['crossattn']
@@ -466,19 +485,43 @@ def get_empty_cond(sd_model):
         return sd_model.cond_stage_model([""])
         return sd_model.cond_stage_model([""])
 
 
 
 
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+    m.to(device="meta")
+    devices.torch_gc()
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    from modules import lowvram, sd_hijack
+    from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
     checkpoint_info = checkpoint_info or select_checkpoint()
 
 
+    timer = Timer()
+
     if model_data.sd_model:
     if model_data.sd_model:
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        send_model_to_trash(model_data.sd_model)
         model_data.sd_model = None
         model_data.sd_model = None
-        gc.collect()
         devices.torch_gc()
         devices.torch_gc()
 
 
-    do_inpainting_hijack()
-
-    timer = Timer()
+    timer.record("unload existing model")
 
 
     if already_loaded_state_dict is not None:
     if already_loaded_state_dict is not None:
         state_dict = already_loaded_state_dict
         state_dict = already_loaded_state_dict
@@ -518,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    timer.record("load weights from state dict")
 
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
-    else:
-        sd_model.to(shared.device)
-
+    send_model_to_device(sd_model)
     timer.record("move model to device")
     timer.record("move model to device")
 
 
     sd_hijack.model_hijack.hijack(sd_model)
     sd_hijack.model_hijack.hijack(sd_model)
@@ -531,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("hijack")
     timer.record("hijack")
 
 
     sd_model.eval()
     sd_model.eval()
-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)
     model_data.was_loaded_at_least_once = True
     model_data.was_loaded_at_least_once = True
 
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -552,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     return sd_model
     return sd_model
 
 
 
 
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+        if shared.opts.sd_checkpoints_keep_in_cpu:
+            send_model_to_cpu(sd_model)
+            timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None
+
+
 def reload_model_weights(sd_model=None, info=None):
 def reload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
     checkpoint_info = info or select_checkpoint()
 
 
+    timer = Timer()
+
     if not sd_model:
     if not sd_model:
         sd_model = model_data.sd_model
         sd_model = model_data.sd_model
 
 
@@ -564,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
     else:
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
         current_checkpoint_info = sd_model.sd_checkpoint_info
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-            return
+            return sd_model
 
 
-        sd_unet.apply_unet("None")
-
-        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-            lowvram.send_everything_to_cpu()
-        else:
-            sd_model.to(devices.cpu)
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model
 
 
+    if sd_model is not None:
+        sd_unet.apply_unet("None")
+        send_model_to_cpu(sd_model)
         sd_hijack.model_hijack.undo_hijack(sd_model)
         sd_hijack.model_hijack.undo_hijack(sd_model)
 
 
-    timer = Timer()
-
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -584,7 +673,9 @@ def reload_model_weights(sd_model=None, info=None):
     timer.record("find config")
     timer.record("find config")
 
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
     if sd_model is None or checkpoint_config != sd_model.used_config:
-        del sd_model
+        if sd_model is not None:
+            send_model_to_trash(sd_model)
+
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
         return model_data.sd_model
 
 
@@ -607,6 +698,9 @@ def reload_model_weights(sd_model=None, info=None):
 
 
     print(f"Weights loaded in {timer.summary()}.")
     print(f"Weights loaded in {timer.summary()}.")
 
 
+    model_data.set_sd_model(sd_model)
+    sd_unet.apply_unet()
+
     return sd_model
     return sd_model
 
 
 
 

+ 4 - 4
modules/sd_models_xl.py

@@ -98,10 +98,10 @@ def extend_sdxl(model):
     model.conditioner.wrapped = torch.nn.Module()
     model.conditioner.wrapped = torch.nn.Module()
 
 
 
 
-sgm.modules.attention.print = lambda *args: None
-sgm.modules.diffusionmodules.model.print = lambda *args: None
-sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
-sgm.modules.encoders.modules.print = lambda *args: None
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
 
 
 # this gets the code to load the vanilla attention that we override
 # this gets the code to load the vanilla attention that we override
 sgm.modules.attention.SDP_IS_AVAILABLE = True
 sgm.modules.attention.SDP_IS_AVAILABLE = True

+ 45 - 11
modules/sd_samplers_common.py

@@ -2,10 +2,8 @@ from collections import namedtuple
 import numpy as np
 import numpy as np
 import torch
 import torch
 from PIL import Image
 from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
-
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
 from modules.shared import opts, state
 from modules.shared import opts, state
-import modules.shared as shared
 
 
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
 
 
@@ -25,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
 approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
 approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
 
 
 
 
-def single_sample_to_image(sample, approximation=None):
+def samples_to_images_tensor(sample, approximation=None, model=None):
+    '''latents -> images [-1, 1]'''
     if approximation is None:
     if approximation is None:
         approximation = approximation_indexes.get(opts.show_progress_type, 0)
         approximation = approximation_indexes.get(opts.show_progress_type, 0)
 
 
     if approximation == 2:
     if approximation == 2:
-        x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
+        x_sample = sd_vae_approx.cheap_approximation(sample)
     elif approximation == 1:
     elif approximation == 1:
-        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
     elif approximation == 3:
     elif approximation == 3:
         x_sample = sample * 1.5
         x_sample = sample * 1.5
-        x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+        x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+        x_sample = x_sample * 2 - 1
     else:
     else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+        if model is None:
+            model = shared.sd_model
+        x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+
+    return x_sample
+
+
+def single_sample_to_image(sample, approximation=None):
+    x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
 
 
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -46,6 +54,12 @@ def single_sample_to_image(sample, approximation=None):
     return Image.fromarray(x_sample)
     return Image.fromarray(x_sample)
 
 
 
 
+def decode_first_stage(model, x):
+    x = x.to(devices.dtype_vae)
+    approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
+    return samples_to_images_tensor(x, approx_index, model)
+
+
 def sample_to_image(samples, index=0, approximation=None):
 def sample_to_image(samples, index=0, approximation=None):
     return single_sample_to_image(samples[index], approximation)
     return single_sample_to_image(samples[index], approximation)
 
 
@@ -54,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None):
     return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
     return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
 
 
 
 
+def images_tensor_to_samples(image, approximation=None, model=None):
+    '''image[0, 1] -> latent'''
+    if approximation is None:
+        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
+
+    if approximation == 3:
+        image = image.to(devices.device, devices.dtype)
+        x_latent = sd_vae_taesd.encoder_model()(image)
+    else:
+        if model is None:
+            model = shared.sd_model
+        image = image.to(shared.device, dtype=devices.dtype_vae)
+        image = image * 2 - 1
+        x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+
+    return x_latent
+
+
 def store_latent(decoded):
 def store_latent(decoded):
     state.current_latent = decoded
     state.current_latent = decoded
 
 
@@ -85,11 +117,13 @@ class InterruptedException(BaseException):
     pass
     pass
 
 
 
 
-if opts.randn_source == "CPU":
+def replace_torchsde_browinan():
     import torchsde._brownian.brownian_interval
     import torchsde._brownian.brownian_interval
 
 
     def torchsde_randn(size, dtype, device, seed):
     def torchsde_randn(size, dtype, device, seed):
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+        return devices.randn_local(seed, size).to(device=device, dtype=dtype)
 
 
     torchsde._brownian.brownian_interval._randn = torchsde_randn
     torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_browinan()

+ 38 - 5
modules/sd_samplers_kdiffusion.py

@@ -4,6 +4,7 @@ import inspect
 import k_diffusion.sampling
 import k_diffusion.sampling
 from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
 from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
 
 
+from modules.processing import StableDiffusionProcessing
 from modules.shared import opts, state
 from modules.shared import opts, state
 import modules.shared as shared
 import modules.shared as shared
 from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
 from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
@@ -30,6 +31,7 @@ samplers_k_diffusion = [
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
     ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
     ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
     ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
     ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+    ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
     ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
     ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
 ]
 ]
 
 
@@ -260,10 +262,7 @@ class TorchHijack:
             if noise.shape == x.shape:
             if noise.shape == x.shape:
                 return noise
                 return noise
 
 
-        if opts.randn_source == "CPU" or x.device.type == 'mps':
-            return torch.randn_like(x, device=devices.cpu).to(x.device)
-        else:
-            return torch.randn_like(x)
+        return devices.randn_like(x)
 
 
 
 
 class KDiffusionSampler:
 class KDiffusionSampler:
@@ -282,6 +281,14 @@ class KDiffusionSampler:
         self.last_latent = None
         self.last_latent = None
         self.s_min_uncond = None
         self.s_min_uncond = None
 
 
+        # NOTE: These are also defined in the StableDiffusionProcessing class.
+        # They should have been here to begin with but we're going to
+        # leave that class __init__ signature alone.
+        self.s_churn = 0.0
+        self.s_tmin = 0.0
+        self.s_tmax = float('inf')
+        self.s_noise = 1.0
+
         self.conditioning_key = sd_model.model.conditioning_key
         self.conditioning_key = sd_model.model.conditioning_key
 
 
     def callback_state(self, d):
     def callback_state(self, d):
@@ -316,7 +323,7 @@ class KDiffusionSampler:
     def number_of_needed_noises(self, p):
     def number_of_needed_noises(self, p):
         return p.steps
         return p.steps
 
 
-    def initialize(self, p):
+    def initialize(self, p: StableDiffusionProcessing):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap_cfg.step = 0
         self.model_wrap_cfg.step = 0
@@ -337,6 +344,29 @@ class KDiffusionSampler:
 
 
             extra_params_kwargs['eta'] = self.eta
             extra_params_kwargs['eta'] = self.eta
 
 
+        if len(self.extra_params) > 0:
+            s_churn = getattr(opts, 's_churn', p.s_churn)
+            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
+            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
+            s_noise = getattr(opts, 's_noise', p.s_noise)
+
+            if s_churn != self.s_churn:
+                extra_params_kwargs['s_churn'] = s_churn
+                p.s_churn = s_churn
+                p.extra_generation_params['Sigma churn'] = s_churn
+            if s_tmin != self.s_tmin:
+                extra_params_kwargs['s_tmin'] = s_tmin
+                p.s_tmin = s_tmin
+                p.extra_generation_params['Sigma tmin'] = s_tmin
+            if s_tmax != self.s_tmax:
+                extra_params_kwargs['s_tmax'] = s_tmax
+                p.s_tmax = s_tmax
+                p.extra_generation_params['Sigma tmax'] = s_tmax
+            if s_noise != self.s_noise:
+                extra_params_kwargs['s_noise'] = s_noise
+                p.s_noise = s_noise
+                p.extra_generation_params['Sigma noise'] = s_noise
+
         return extra_params_kwargs
         return extra_params_kwargs
 
 
     def get_sigmas(self, p, steps):
     def get_sigmas(self, p, steps):
@@ -378,6 +408,9 @@ class KDiffusionSampler:
             sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
             sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
 
 
             sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
             sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+        elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
+            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+            sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
         else:
         else:
             sigmas = self.model_wrap.get_sigmas(steps)
             sigmas = self.model_wrap.get_sigmas(steps)
 
 

+ 66 - 13
modules/sd_vae.py

@@ -1,6 +1,8 @@
 import os
 import os
 import collections
 import collections
-from modules import paths, shared, devices, script_callbacks, sd_models
+from dataclasses import dataclass
+
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 import glob
 from copy import deepcopy
 from copy import deepcopy
 
 
@@ -16,6 +18,7 @@ checkpoint_info = None
 
 
 checkpoints_loaded = collections.OrderedDict()
 checkpoints_loaded = collections.OrderedDict()
 
 
+
 def get_base_vae(model):
 def get_base_vae(model):
     if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
     if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
         return base_vae
         return base_vae
@@ -50,6 +53,7 @@ def get_filename(filepath):
 
 
 
 
 def refresh_vae_list():
 def refresh_vae_list():
+    global vae_dict
     vae_dict.clear()
     vae_dict.clear()
 
 
     paths = [
     paths = [
@@ -83,6 +87,8 @@ def refresh_vae_list():
         name = get_filename(filepath)
         name = get_filename(filepath)
         vae_dict[name] = filepath
         vae_dict[name] = filepath
 
 
+    vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+
 
 
 def find_vae_near_checkpoint(checkpoint_file):
 def find_vae_near_checkpoint(checkpoint_file):
     checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
     checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
@@ -93,27 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
     return None
     return None
 
 
 
 
-def resolve_vae(checkpoint_file):
-    if shared.cmd_opts.vae_path is not None:
-        return shared.cmd_opts.vae_path, 'from commandline argument'
+@dataclass
+class VaeResolution:
+    vae: str = None
+    source: str = None
+    resolved: bool = True
 
 
-    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+    def tuple(self):
+        return self.vae, self.source
 
 
-    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
-    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
-        return vae_near_checkpoint, 'found near the checkpoint'
 
 
+def is_automatic():
+    return shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+
+
+def resolve_vae_from_setting() -> VaeResolution:
     if shared.opts.sd_vae == "None":
     if shared.opts.sd_vae == "None":
-        return None, None
+        return VaeResolution()
 
 
     vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
     vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
     if vae_from_options is not None:
     if vae_from_options is not None:
-        return vae_from_options, 'specified in settings'
+        return VaeResolution(vae_from_options, 'specified in settings')
 
 
-    if not is_automatic:
+    if not is_automatic():
         print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
         print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
 
 
-    return None, None
+    return VaeResolution(resolved=False)
+
+
+def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
+    metadata = extra_networks.get_user_metadata(checkpoint_file)
+    vae_metadata = metadata.get("vae", None)
+    if vae_metadata is not None and vae_metadata != "Automatic":
+        if vae_metadata == "None":
+            return VaeResolution()
+
+        vae_from_metadata = vae_dict.get(vae_metadata, None)
+        if vae_from_metadata is not None:
+            return VaeResolution(vae_from_metadata, "from user metadata")
+
+    return VaeResolution(resolved=False)
+
+
+def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
+    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
+        return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
+
+    return VaeResolution(resolved=False)
+
+
+def resolve_vae(checkpoint_file) -> VaeResolution:
+    if shared.cmd_opts.vae_path is not None:
+        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
+
+    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
+        return resolve_vae_from_setting()
+
+    res = resolve_vae_from_user_metadata(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_near_checkpoint(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_from_setting()
+
+    return res
 
 
 
 
 def load_vae_dict(filename, map_location):
 def load_vae_dict(filename, map_location):
@@ -187,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     checkpoint_file = checkpoint_info.filename
     checkpoint_file = checkpoint_info.filename
 
 
     if vae_file == unspecified:
     if vae_file == unspecified:
-        vae_file, vae_source = resolve_vae(checkpoint_file)
+        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
     else:
     else:
         vae_source = "from function argument"
         vae_source = "from function argument"
 
 

+ 1 - 1
modules/sd_vae_approx.py

@@ -81,6 +81,6 @@ def cheap_approximation(sample):
 
 
     coefs = torch.tensor(coeffs).to(sample.device)
     coefs = torch.tensor(coeffs).to(sample.device)
 
 
-    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+    x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
 
 
     return x_sample
     return x_sample

+ 44 - 8
modules/sd_vae_taesd.py

@@ -44,7 +44,17 @@ def decoder():
     )
     )
 
 
 
 
-class TAESD(nn.Module):
+def encoder():
+    return nn.Sequential(
+        conv(3, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+        conv(64, 4),
+    )
+
+
+class TAESDDecoder(nn.Module):
     latent_magnitude = 3
     latent_magnitude = 3
     latent_shift = 0.5
     latent_shift = 0.5
 
 
@@ -55,21 +65,28 @@ class TAESD(nn.Module):
         self.decoder.load_state_dict(
         self.decoder.load_state_dict(
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
 
-    @staticmethod
-    def unscale_latents(x):
-        """[0, 1] -> raw latents"""
-        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+class TAESDEncoder(nn.Module):
+    latent_magnitude = 3
+    latent_shift = 0.5
+
+    def __init__(self, encoder_path="taesd_encoder.pth"):
+        """Initialize pretrained TAESD on the given device from the given checkpoints."""
+        super().__init__()
+        self.encoder = encoder()
+        self.encoder.load_state_dict(
+            torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
 
 
 
 def download_model(model_path, model_url):
 def download_model(model_path, model_url):
     if not os.path.exists(model_path):
     if not os.path.exists(model_path):
         os.makedirs(os.path.dirname(model_path), exist_ok=True)
         os.makedirs(os.path.dirname(model_path), exist_ok=True)
 
 
-        print(f'Downloading TAESD decoder to: {model_path}')
+        print(f'Downloading TAESD model to: {model_path}')
         torch.hub.download_url_to_file(model_url, model_path)
         torch.hub.download_url_to_file(model_url, model_path)
 
 
 
 
-def model():
+def decoder_model():
     model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
     model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
     loaded_model = sd_vae_taesd_models.get(model_name)
     loaded_model = sd_vae_taesd_models.get(model_name)
 
 
@@ -78,7 +95,7 @@ def model():
         download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
         download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
 
 
         if os.path.exists(model_path):
         if os.path.exists(model_path):
-            loaded_model = TAESD(model_path)
+            loaded_model = TAESDDecoder(model_path)
             loaded_model.eval()
             loaded_model.eval()
             loaded_model.to(devices.device, devices.dtype)
             loaded_model.to(devices.device, devices.dtype)
             sd_vae_taesd_models[model_name] = loaded_model
             sd_vae_taesd_models[model_name] = loaded_model
@@ -86,3 +103,22 @@ def model():
             raise FileNotFoundError('TAESD model not found')
             raise FileNotFoundError('TAESD model not found')
 
 
     return loaded_model.decoder
     return loaded_model.decoder
+
+
+def encoder_model():
+    model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)
+
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
+
+        if os.path.exists(model_path):
+            loaded_model = TAESDEncoder(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
+        else:
+            raise FileNotFoundError('TAESD model not found')
+
+    return loaded_model.encoder

+ 129 - 45
modules/shared.py

@@ -51,15 +51,34 @@ restricted_opts = {
 
 
 # https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
 # https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
 gradio_hf_hub_themes = [
 gradio_hf_hub_themes = [
+    "gradio/base",
     "gradio/glass",
     "gradio/glass",
     "gradio/monochrome",
     "gradio/monochrome",
     "gradio/seafoam",
     "gradio/seafoam",
     "gradio/soft",
     "gradio/soft",
-    "freddyaboulton/dracula_revamped",
     "gradio/dracula_test",
     "gradio/dracula_test",
     "abidlabs/dracula_test",
     "abidlabs/dracula_test",
+    "abidlabs/Lime",
     "abidlabs/pakistan",
     "abidlabs/pakistan",
+    "Ama434/neutral-barlow",
     "dawood/microsoft_windows",
     "dawood/microsoft_windows",
+    "finlaymacklon/smooth_slate",
+    "Franklisi/darkmode",
+    "freddyaboulton/dracula_revamped",
+    "freddyaboulton/test-blue",
+    "gstaff/xkcd",
+    "Insuz/Mocha",
+    "Insuz/SimpleIndigo",
+    "JohnSmith9982/small_and_pretty",
+    "nota-ai/theme",
+    "nuttea/Softblue",
+    "ParityError/Anime",
+    "reilnuud/polite",
+    "remilia/Ghostly",
+    "rottenlittlecreature/Moon_Goblin",
+    "step-3-profit/Midnight-Deep",
+    "Taithrah/Minimal",
+    "ysharma/huggingface",
     "ysharma/steampunk"
     "ysharma/steampunk"
 ]
 ]
 
 
@@ -220,12 +239,19 @@ class State:
             return
             return
 
 
         import modules.sd_samplers
         import modules.sd_samplers
-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
 
 
-        self.current_image_sampling_step = self.sampling_step
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+
+        except Exception:
+            # when switching models during genration, VAE would be on CPU, so creating an image will fail.
+            # we silently ignore this error
+            errors.record_exception()
 
 
     def assign_current_image(self, image):
     def assign_current_image(self, image):
         self.current_image = image
         self.current_image = image
@@ -252,6 +278,7 @@ class OptionInfo:
         self.onchange = onchange
         self.onchange = onchange
         self.section = section
         self.section = section
         self.refresh = refresh
         self.refresh = refresh
+        self.do_not_save = False
 
 
         self.comment_before = comment_before
         self.comment_before = comment_before
         """HTML text that will be added after label in UI"""
         """HTML text that will be added after label in UI"""
@@ -279,7 +306,16 @@ class OptionInfo:
         self.comment_after += " <span class='info'>(requires restart)</span>"
         self.comment_after += " <span class='info'>(requires restart)</span>"
         return self
         return self
 
 
+    def needs_reload_ui(self):
+        self.comment_after += " <span class='info'>(requires Reload UI)</span>"
+        return self
+
+
+class OptionHTML(OptionInfo):
+    def __init__(self, text):
+        super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
 
 
+        self.do_not_save = True
 
 
 
 
 def options_section(section_identifier, options_dict):
 def options_section(section_identifier, options_dict):
@@ -349,6 +385,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
     "temp_dir":  OptionInfo("", "Directory for temporary images; leave empty for default"),
     "temp_dir":  OptionInfo("", "Directory for temporary images; leave empty for default"),
     "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
     "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
 
 
+    "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
 }))
 }))
 
 
 options_templates.update(options_section(('saving-paths', "Paths for saving"), {
 options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -386,13 +423,15 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
 
 
 options_templates.update(options_section(('system', "System"), {
 options_templates.update(options_section(('system', "System"), {
     "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
     "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
-    "show_warnings": OptionInfo(False, "Show warnings in console."),
+    "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
+    "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
 }))
 }))
 
 
 options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('training', "Training"), {
@@ -412,24 +451,17 @@ options_templates.update(options_section(('training', "Training"), {
 
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
-    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
-    "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+    "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
-    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
-    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
-    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
-    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
-    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
+    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
     "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
     "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
     "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
-    "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
-    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
 }))
 }))
 
 
 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
@@ -439,6 +471,35 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
     "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
     "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
 }))
 }))
 
 
+options_templates.update(options_section(('vae', "VAE"), {
+    "sd_vae_explanation": OptionHTML("""
+<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
+image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
+(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
+For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
+"""),
+    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
+    "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
+    "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
+    "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
+    "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
+}))
+
+options_templates.update(options_section(('img2img', "img2img"), {
+    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
+    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
+    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
+    "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
+    "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
+    "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker,  {}).info("brush color of inpaint mask").needs_reload_ui(),
+    "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
+    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
+}))
+
 options_templates.update(options_section(('optimizations', "Optimizations"), {
 options_templates.update(options_section(('optimizations', "Optimizations"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
     "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
     "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
@@ -446,7 +507,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
-    "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
+    "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"),
 }))
 }))
 
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
 options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -458,7 +519,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
 }))
 }))
 
 
-options_templates.update(options_section(('interrogate', "Interrogate Options"), {
+options_templates.update(options_section(('interrogate', "Interrogate"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
     "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
     "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
     "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
     "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
     "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
@@ -482,19 +543,17 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
     "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
     "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
     "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
     "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
     "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
     "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
-    "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
+    "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
     "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
     "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
     "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
     "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
     "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
     "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
 }))
 }))
 
 
 options_templates.update(options_section(('ui', "User interface"), {
 options_templates.update(options_section(('ui', "User interface"), {
-    "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
-    "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
-    "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
+    "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
+    "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
+    "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
-    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
-    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -503,21 +562,22 @@ options_templates.update(options_section(('ui', "User interface"), {
     "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
     "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
     "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
     "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
     "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
     "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
-    "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
-    "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
+    "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
+    "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
     "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
     "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
     "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
     "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
-    "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
-    "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
-    "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
-    "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
-    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
-    "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
-    "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
+    "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
+    "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
+    "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(),
+    "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
+    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
+    "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
+    "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
 }))
 }))
 
 
+
 options_templates.update(options_section(('infotext', "Infotext"), {
 options_templates.update(options_section(('infotext', "Infotext"), {
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
@@ -545,17 +605,18 @@ options_templates.update(options_section(('ui', "Live previews"), {
 }))
 }))
 
 
 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
-    "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
+    "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(),
     "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
     "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
     "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
     "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
     "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
     "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
-    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    's_tmin':  OptionInfo(0.0, "sigma tmin",  gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    'k_sched_type':  OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
+    's_tmin':  OptionInfo(0.0, "sigma tmin",  gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
+    's_tmax':  OptionInfo(0.0, "sigma tmax",  gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
+    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
+    'k_sched_type':  OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
     'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
     'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
-    'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
-    'rho':  OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
+    'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
+    'rho':  OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
     'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
     'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
     'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
     'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
     'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
     'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
@@ -595,6 +656,9 @@ class Options:
                 assert not cmd_opts.freeze_settings, "changing settings is disabled"
                 assert not cmd_opts.freeze_settings, "changing settings is disabled"
 
 
                 info = opts.data_labels.get(key, None)
                 info = opts.data_labels.get(key, None)
+                if info.do_not_save:
+                    return
+
                 comp_args = info.component_args if info else None
                 comp_args = info.component_args if info else None
                 if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
                 if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
                     raise RuntimeError(f"not possible to set {key} because it is restricted")
                     raise RuntimeError(f"not possible to set {key} because it is restricted")
@@ -624,6 +688,9 @@ class Options:
         if oldval == value:
         if oldval == value:
             return False
             return False
 
 
+        if self.data_labels[key].do_not_save:
+            return False
+
         try:
         try:
             setattr(self, key, value)
             setattr(self, key, value)
         except RuntimeError:
         except RuntimeError:
@@ -667,6 +734,10 @@ class Options:
         with open(filename, "r", encoding="utf8") as file:
         with open(filename, "r", encoding="utf8") as file:
             self.data = json.load(file)
             self.data = json.load(file)
 
 
+        # 1.6.0 VAE defaults
+        if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+            self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
         # 1.1.1 quicksettings list migration
         # 1.1.1 quicksettings list migration
         if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
         if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
             self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
             self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
@@ -800,13 +871,19 @@ def reload_gradio_theme(theme_name=None):
         gradio_theme = gr.themes.Default(**default_theme_args)
         gradio_theme = gr.themes.Default(**default_theme_args)
     else:
     else:
         try:
         try:
-            gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
+            theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
+            theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
+            if opts.gradio_themes_cache and os.path.exists(theme_cache_path):
+                gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
+            else:
+                os.makedirs(theme_cache_dir, exist_ok=True)
+                gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
+                gradio_theme.dump(theme_cache_path)
         except Exception as e:
         except Exception as e:
             errors.display(e, "changing gradio theme")
             errors.display(e, "changing gradio theme")
             gradio_theme = gr.themes.Default(**default_theme_args)
             gradio_theme = gr.themes.Default(**default_theme_args)
 
 
 
 
-
 class TotalTQDM:
 class TotalTQDM:
     def __init__(self):
     def __init__(self):
         self._tqdm = None
         self._tqdm = None
@@ -890,3 +967,10 @@ def walk_files(path, allowed_extensions=None):
                 continue
                 continue
 
 
             yield os.path.join(root, filename)
             yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+    if opts.hide_ldm_prints:
+        return
+
+    print(*args, **kwargs)

+ 1 - 4
modules/styles.py

@@ -106,10 +106,7 @@ class StyleDatabase:
         if os.path.exists(path):
         if os.path.exists(path):
             shutil.copy(path, f"{path}.bak")
             shutil.copy(path, f"{path}.bak")
 
 
-        fd = os.open(path, os.O_RDWR | os.O_CREAT)
-        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
-            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
-            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+        with open(path, "w", encoding="utf-8-sig", newline='') as file:
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
             writer.writeheader()
             writer.writeheader()
             writer.writerows(style._asdict() for k, style in self.styles.items())
             writer.writerows(style._asdict() for k, style in self.styles.items())

+ 3 - 1
modules/textual_inversion/textual_inversion.py

@@ -13,7 +13,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from PIL import Image, PngImagePlugin
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard import SummaryWriter
 
 
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
 import modules.textual_inversion.dataset
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
 
@@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
 
 
 
 
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    from modules import processing
+
     save_embedding_every = save_embedding_every or 0
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)
     template_file = textual_inversion_templates.get(template_filename, None)

+ 2 - 1
modules/txt2img.py

@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr
 import gradio as gr
 
 
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
     override_settings = create_override_settings_dict(override_settings_texts)
 
 
     p = processing.StableDiffusionProcessingTxt2Img(
     p = processing.StableDiffusionProcessingTxt2Img(
@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_second_pass_steps=hr_second_pass_steps,
         hr_second_pass_steps=hr_second_pass_steps,
         hr_resize_x=hr_resize_x,
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
         hr_resize_y=hr_resize_y,
+        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
         hr_prompt=hr_prompt,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,
         hr_negative_prompt=hr_negative_prompt,

+ 160 - 213
modules/ui.py

@@ -12,39 +12,38 @@ import numpy as np
 from PIL import Image, PngImagePlugin  # noqa: F401
 from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
 
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
+from modules import gradio_extensons  # noqa: F401
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button
 from modules.ui_common import create_refresh_button
 from modules.ui_gradio_extensions import reload_javascript
 from modules.ui_gradio_extensions import reload_javascript
 
 
-
 from modules.shared import opts, cmd_opts
 from modules.shared import opts, cmd_opts
 
 
-import modules.codeformer_model
 import modules.generation_parameters_copypaste as parameters_copypaste
 import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
+import modules.hypernetworks.ui as hypernetworks_ui
+import modules.textual_inversion.ui as textual_inversion_ui
+import modules.textual_inversion.textual_inversion as textual_inversion
 import modules.shared as shared
 import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
+import modules.images
 from modules import prompt_parser
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
 from modules.generation_parameters_copypaste import image_from_url_text
 from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras
 
 
 create_setting_component = ui_settings.create_setting_component
 create_setting_component = ui_settings.create_setting_component
 
 
 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
 
 
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
 mimetypes.init()
 mimetypes.add_type('application/javascript', '.js')
 mimetypes.add_type('application/javascript', '.js')
 
 
+# Likewise, add explicit content-type header for certain missing image types
+mimetypes.add_type('image/webp', '.webp')
+
 if not cmd_opts.share and not cmd_opts.listen:
 if not cmd_opts.share and not cmd_opts.listen:
     # fix gradio phoning home
     # fix gradio phoning home
     gradio.utils.version_check = lambda: None
     gradio.utils.version_check = lambda: None
@@ -92,19 +91,6 @@ def send_gradio_gallery_to_image(x):
     return image_from_url_text(x[0])
     return image_from_url_text(x[0])
 
 
 
 
-def add_style(name: str, prompt: str, negative_prompt: str):
-    if name is None:
-        return [gr_show() for x in range(4)]
-
-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
-    shared.prompt_styles.styles[style.name] = style
-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
-    # reserialize all styles every time we save them
-    shared.prompt_styles.save_styles(shared.styles_filename)
-
-    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
 def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
 def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
     from modules import processing, devices
     from modules import processing, devices
 
 
@@ -129,13 +115,6 @@ def resize_from_to_html(width, height, scale_by):
     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
 
 
 
 
-def apply_styles(prompt, prompt_neg, styles):
-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
-    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
-    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
     if mode in {0, 1, 3, 4}:
     if mode in {0, 1, 3, 4}:
         return [interrogation_function(ii_singles[mode]), None]
         return [interrogation_function(ii_singles[mode]), None]
@@ -172,7 +151,6 @@ def interrogate_deepbooru(image):
 def create_seed_inputs(target_interface):
 def create_seed_inputs(target_interface):
     with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
     with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
         seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
         seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
-        seed.style(container=False)
         random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
         random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
         reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
         reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
 
 
@@ -184,7 +162,6 @@ def create_seed_inputs(target_interface):
     with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
     with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
         seed_extras.append(seed_extra_row_1)
         seed_extras.append(seed_extra_row_1)
         subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
         subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
-        subseed.style(container=False)
         random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
         random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
         reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
         reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
         subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
         subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
@@ -267,71 +244,76 @@ def update_token_counter(text, steps):
     return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
     return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
 
 
 
 
-def create_toprow(is_img2img):
-    id_part = "img2img" if is_img2img else "txt2img"
+class Toprow:
+    """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
 
 
-    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
-        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+    def __init__(self, is_img2img):
+        id_part = "img2img" if is_img2img else "txt2img"
+        self.id_part = id_part
 
 
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
-        button_interrogate = None
-        button_deepbooru = None
-        if is_img2img:
-            with gr.Column(scale=1, elem_classes="interrogate-col"):
-                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
-                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
-        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
-            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
-                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
-                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
-                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
-                skip.click(
-                    fn=lambda: shared.state.skip(),
-                    inputs=[],
-                    outputs=[],
-                )
+        with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+            with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+                            self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
 
 
-                interrupt.click(
-                    fn=lambda: shared.state.interrupt(),
-                    inputs=[],
-                    outputs=[],
-                )
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+
+            self.button_interrogate = None
+            self.button_deepbooru = None
+            if is_img2img:
+                with gr.Column(scale=1, elem_classes="interrogate-col"):
+                    self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+                    self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+            with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+                with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+                    self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+                    self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+                    self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+                    self.skip.click(
+                        fn=lambda: shared.state.skip(),
+                        inputs=[],
+                        outputs=[],
+                    )
 
 
-            with gr.Row(elem_id=f"{id_part}_tools"):
-                paste = ToolButton(value=paste_symbol, elem_id="paste")
-                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
-                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
-                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
-                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
-                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
-                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
-                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
-                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
-                clear_prompt_button.click(
-                    fn=lambda *x: x,
-                    _js="confirm_clear_prompt",
-                    inputs=[prompt, negative_prompt],
-                    outputs=[prompt, negative_prompt],
-                )
+                    self.interrupt.click(
+                        fn=lambda: shared.state.interrupt(),
+                        inputs=[],
+                        outputs=[],
+                    )
+
+                with gr.Row(elem_id=f"{id_part}_tools"):
+                    self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+                    self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+                    self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+                    self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+                    self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+                    self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+                    self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+                    self.clear_prompt_button.click(
+                        fn=lambda *x: x,
+                        _js="confirm_clear_prompt",
+                        inputs=[self.prompt, self.negative_prompt],
+                        outputs=[self.prompt, self.negative_prompt],
+                    )
 
 
-            with gr.Row(elem_id=f"{id_part}_styles_row"):
-                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
-                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+                self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
 
 
-    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+        self.prompt_img.change(
+            fn=modules.images.image_data,
+            inputs=[self.prompt_img],
+            outputs=[self.prompt, self.prompt_img],
+            show_progress=False,
+        )
 
 
 
 
 def setup_progressbar(*args, **kwargs):
 def setup_progressbar(*args, **kwargs):
@@ -415,22 +397,20 @@ def create_ui():
 
 
     parameters_copypaste.reset()
     parameters_copypaste.reset()
 
 
-    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
-    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
+    scripts.scripts_current = scripts.scripts_txt2img
+    scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
 
 
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
-        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+        toprow = Toprow(is_img2img=False)
 
 
         dummy_component = gr.Label(visible=False)
         dummy_component = gr.Label(visible=False)
-        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
 
 
-        with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
-            from modules import ui_extra_networks
-            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+        extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
+        extra_tabs.__enter__()
 
 
-        with gr.Row().style(equal_height=False):
+        with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
-                modules.scripts.scripts_txt2img.prepare_ui()
+                scripts.scripts_txt2img.prepare_ui()
 
 
                 for category in ordered_ui_categories():
                 for category in ordered_ui_categories():
                     if category == "sampler":
                     if category == "sampler":
@@ -476,6 +456,10 @@ def create_ui():
                                 hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
                                 hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
 
 
                             with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
                             with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+
+                                hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+                                create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+
                                 hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
                                 hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
 
 
                             with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
                             with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
@@ -498,10 +482,10 @@ def create_ui():
 
 
                     elif category == "scripts":
                     elif category == "scripts":
                         with FormGroup(elem_id="txt2img_script_container"):
                         with FormGroup(elem_id="txt2img_script_container"):
-                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+                            custom_inputs = scripts.scripts_txt2img.setup_ui()
 
 
                     else:
                     else:
-                        modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+                        scripts.scripts_txt2img.setup_ui_for_section(category)
 
 
             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
 
 
@@ -532,9 +516,9 @@ def create_ui():
                 _js="submit",
                 _js="submit",
                 inputs=[
                 inputs=[
                     dummy_component,
                     dummy_component,
-                    txt2img_prompt,
-                    txt2img_negative_prompt,
-                    txt2img_prompt_styles,
+                    toprow.prompt,
+                    toprow.negative_prompt,
+                    toprow.ui_styles.dropdown,
                     steps,
                     steps,
                     sampler_index,
                     sampler_index,
                     restore_faces,
                     restore_faces,
@@ -553,6 +537,7 @@ def create_ui():
                     hr_second_pass_steps,
                     hr_second_pass_steps,
                     hr_resize_x,
                     hr_resize_x,
                     hr_resize_y,
                     hr_resize_y,
+                    hr_checkpoint_name,
                     hr_sampler_index,
                     hr_sampler_index,
                     hr_prompt,
                     hr_prompt,
                     hr_negative_prompt,
                     hr_negative_prompt,
@@ -569,12 +554,12 @@ def create_ui():
                 show_progress=False,
                 show_progress=False,
             )
             )
 
 
-            txt2img_prompt.submit(**txt2img_args)
-            submit.click(**txt2img_args)
+            toprow.prompt.submit(**txt2img_args)
+            toprow.submit.click(**txt2img_args)
 
 
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
 
 
-            restore_progress_button.click(
+            toprow.restore_progress_button.click(
                 fn=progress.restore_progress,
                 fn=progress.restore_progress,
                 _js="restoreProgressTxt2img",
                 _js="restoreProgressTxt2img",
                 inputs=[dummy_component],
                 inputs=[dummy_component],
@@ -587,18 +572,6 @@ def create_ui():
                 show_progress=False,
                 show_progress=False,
             )
             )
 
 
-            txt_prompt_img.change(
-                fn=modules.images.image_data,
-                inputs=[
-                    txt_prompt_img
-                ],
-                outputs=[
-                    txt2img_prompt,
-                    txt_prompt_img
-                ],
-                show_progress=False,
-            )
-
             enable_hr.change(
             enable_hr.change(
                 fn=lambda x: gr_show(x),
                 fn=lambda x: gr_show(x),
                 inputs=[enable_hr],
                 inputs=[enable_hr],
@@ -607,8 +580,8 @@ def create_ui():
             )
             )
 
 
             txt2img_paste_fields = [
             txt2img_paste_fields = [
-                (txt2img_prompt, "Prompt"),
-                (txt2img_negative_prompt, "Negative prompt"),
+                (toprow.prompt, "Prompt"),
+                (toprow.negative_prompt, "Negative prompt"),
                 (steps, "Steps"),
                 (steps, "Steps"),
                 (sampler_index, "Sampler"),
                 (sampler_index, "Sampler"),
                 (restore_faces, "Face restoration"),
                 (restore_faces, "Face restoration"),
@@ -617,34 +590,36 @@ def create_ui():
                 (width, "Size-1"),
                 (width, "Size-1"),
                 (height, "Size-2"),
                 (height, "Size-2"),
                 (batch_size, "Batch size"),
                 (batch_size, "Batch size"),
+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
                 (subseed, "Variation seed"),
                 (subseed, "Variation seed"),
                 (subseed_strength, "Variation seed strength"),
                 (subseed_strength, "Variation seed strength"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
                 (seed_resize_from_h, "Seed resize from-2"),
-                (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                 (denoising_strength, "Denoising strength"),
                 (denoising_strength, "Denoising strength"),
-                (enable_hr, lambda d: "Denoising strength" in d),
-                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+                (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
+                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
                 (hr_scale, "Hires upscale"),
                 (hr_scale, "Hires upscale"),
                 (hr_upscaler, "Hires upscaler"),
                 (hr_upscaler, "Hires upscaler"),
                 (hr_second_pass_steps, "Hires steps"),
                 (hr_second_pass_steps, "Hires steps"),
                 (hr_resize_x, "Hires resize-1"),
                 (hr_resize_x, "Hires resize-1"),
                 (hr_resize_y, "Hires resize-2"),
                 (hr_resize_y, "Hires resize-2"),
+                (hr_checkpoint_name, "Hires checkpoint"),
                 (hr_sampler_index, "Hires sampler"),
                 (hr_sampler_index, "Hires sampler"),
-                (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
+                (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
                 (hr_prompt, "Hires prompt"),
                 (hr_prompt, "Hires prompt"),
                 (hr_negative_prompt, "Hires negative prompt"),
                 (hr_negative_prompt, "Hires negative prompt"),
                 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
                 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
-                *modules.scripts.scripts_txt2img.infotext_fields
+                *scripts.scripts_txt2img.infotext_fields
             ]
             ]
             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+                paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
             ))
             ))
 
 
             txt2img_preview_params = [
             txt2img_preview_params = [
-                txt2img_prompt,
-                txt2img_negative_prompt,
+                toprow.prompt,
+                toprow.negative_prompt,
                 steps,
                 steps,
                 sampler_index,
                 sampler_index,
                 cfg_scale,
                 cfg_scale,
@@ -653,24 +628,25 @@ def create_ui():
                 height,
                 height,
             ]
             ]
 
 
-            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+            toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+            toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
 
-            ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+        from modules import ui_extra_networks
+        extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
+        ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
 
 
-    modules.scripts.scripts_current = modules.scripts.scripts_img2img
-    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+        extra_tabs.__exit__()
 
 
-    with gr.Blocks(analytics_enabled=False) as img2img_interface:
-        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
+    scripts.scripts_current = scripts.scripts_img2img
+    scripts.scripts_img2img.initialize_scripts(is_img2img=True)
 
 
-        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+    with gr.Blocks(analytics_enabled=False) as img2img_interface:
+        toprow = Toprow(is_img2img=True)
 
 
-        with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
-            from modules import ui_extra_networks
-            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+        extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
+        extra_tabs.__enter__()
 
 
-        with FormRow().style(equal_height=False):
+        with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="img2img_settings"):
             with gr.Column(variant='compact', elem_id="img2img_settings"):
                 copy_image_buttons = []
                 copy_image_buttons = []
                 copy_image_destinations = {}
                 copy_image_destinations = {}
@@ -692,19 +668,19 @@ def create_ui():
                     img2img_selected_tab = gr.State(0)
                     img2img_selected_tab = gr.State(0)
 
 
                     with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
                     with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
-                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
                         add_copy_image_controls('img2img', init_img)
                         add_copy_image_controls('img2img', init_img)
 
 
                     with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
                     with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
-                        sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
                         add_copy_image_controls('sketch', sketch)
                         add_copy_image_controls('sketch', sketch)
 
 
                     with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
                     with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
-                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
                         add_copy_image_controls('inpaint', init_img_with_mask)
                         add_copy_image_controls('inpaint', init_img_with_mask)
 
 
                     with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
                     with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
-                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
                         inpaint_color_sketch_orig = gr.State(None)
                         inpaint_color_sketch_orig = gr.State(None)
                         add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
                         add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
 
 
@@ -764,7 +740,7 @@ def create_ui():
                 with FormRow():
                 with FormRow():
                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
 
 
-                modules.scripts.scripts_img2img.prepare_ui()
+                scripts.scripts_img2img.prepare_ui()
 
 
                 for category in ordered_ui_categories():
                 for category in ordered_ui_categories():
                     if category == "sampler":
                     if category == "sampler":
@@ -845,7 +821,7 @@ def create_ui():
 
 
                     elif category == "scripts":
                     elif category == "scripts":
                         with FormGroup(elem_id="img2img_script_container"):
                         with FormGroup(elem_id="img2img_script_container"):
-                            custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+                            custom_inputs = scripts.scripts_img2img.setup_ui()
 
 
                     elif category == "inpaint":
                     elif category == "inpaint":
                         with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
                         with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
@@ -876,34 +852,22 @@ def create_ui():
                                     outputs=[inpaint_controls, mask_alpha],
                                     outputs=[inpaint_controls, mask_alpha],
                                 )
                                 )
                     else:
                     else:
-                        modules.scripts.scripts_img2img.setup_ui_for_section(category)
+                        scripts.scripts_img2img.setup_ui_for_section(category)
 
 
             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
 
 
             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
 
 
-            img2img_prompt_img.change(
-                fn=modules.images.image_data,
-                inputs=[
-                    img2img_prompt_img
-                ],
-                outputs=[
-                    img2img_prompt,
-                    img2img_prompt_img
-                ],
-                show_progress=False,
-            )
-
             img2img_args = dict(
             img2img_args = dict(
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                 _js="submit_img2img",
                 _js="submit_img2img",
                 inputs=[
                 inputs=[
                     dummy_component,
                     dummy_component,
                     dummy_component,
                     dummy_component,
-                    img2img_prompt,
-                    img2img_negative_prompt,
-                    img2img_prompt_styles,
+                    toprow.prompt,
+                    toprow.negative_prompt,
+                    toprow.ui_styles.dropdown,
                     init_img,
                     init_img,
                     sketch,
                     sketch,
                     init_img_with_mask,
                     init_img_with_mask,
@@ -962,11 +926,11 @@ def create_ui():
                     inpaint_color_sketch,
                     inpaint_color_sketch,
                     init_img_inpaint,
                     init_img_inpaint,
                 ],
                 ],
-                outputs=[img2img_prompt, dummy_component],
+                outputs=[toprow.prompt, dummy_component],
             )
             )
 
 
-            img2img_prompt.submit(**img2img_args)
-            submit.click(**img2img_args)
+            toprow.prompt.submit(**img2img_args)
+            toprow.submit.click(**img2img_args)
 
 
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
 
 
@@ -978,7 +942,7 @@ def create_ui():
                 show_progress=False,
                 show_progress=False,
             )
             )
 
 
-            restore_progress_button.click(
+            toprow.restore_progress_button.click(
                 fn=progress.restore_progress,
                 fn=progress.restore_progress,
                 _js="restoreProgressImg2img",
                 _js="restoreProgressImg2img",
                 inputs=[dummy_component],
                 inputs=[dummy_component],
@@ -991,46 +955,22 @@ def create_ui():
                 show_progress=False,
                 show_progress=False,
             )
             )
 
 
-            img2img_interrogate.click(
+            toprow.button_interrogate.click(
                 fn=lambda *args: process_interrogate(interrogate, *args),
                 fn=lambda *args: process_interrogate(interrogate, *args),
                 **interrogate_args,
                 **interrogate_args,
             )
             )
 
 
-            img2img_deepbooru.click(
+            toprow.button_deepbooru.click(
                 fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                 fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                 **interrogate_args,
                 **interrogate_args,
             )
             )
 
 
-            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
-            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
-            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
-            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
-                button.click(
-                    fn=add_style,
-                    _js="ask_for_style_name",
-                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
-                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
-                    inputs=[dummy_component, prompt, negative_prompt],
-                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],
-                )
-
-            for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
-                button.click(
-                    fn=apply_styles,
-                    _js=js_func,
-                    inputs=[prompt, negative_prompt, styles],
-                    outputs=[prompt, negative_prompt, styles],
-                )
-
-            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
-            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
-
-            ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+            toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+            toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
 
             img2img_paste_fields = [
             img2img_paste_fields = [
-                (img2img_prompt, "Prompt"),
-                (img2img_negative_prompt, "Negative prompt"),
+                (toprow.prompt, "Prompt"),
+                (toprow.negative_prompt, "Negative prompt"),
                 (steps, "Steps"),
                 (steps, "Steps"),
                 (sampler_index, "Sampler"),
                 (sampler_index, "Sampler"),
                 (restore_faces, "Face restoration"),
                 (restore_faces, "Face restoration"),
@@ -1040,28 +980,35 @@ def create_ui():
                 (width, "Size-1"),
                 (width, "Size-1"),
                 (height, "Size-2"),
                 (height, "Size-2"),
                 (batch_size, "Batch size"),
                 (batch_size, "Batch size"),
+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
                 (subseed, "Variation seed"),
                 (subseed, "Variation seed"),
                 (subseed_strength, "Variation seed strength"),
                 (subseed_strength, "Variation seed strength"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
                 (seed_resize_from_h, "Seed resize from-2"),
-                (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                 (denoising_strength, "Denoising strength"),
                 (denoising_strength, "Denoising strength"),
                 (mask_blur, "Mask blur"),
                 (mask_blur, "Mask blur"),
-                *modules.scripts.scripts_img2img.infotext_fields
+                *scripts.scripts_img2img.infotext_fields
             ]
             ]
             parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
             parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
             parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
             parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+                paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
             ))
             ))
 
 
-    modules.scripts.scripts_current = None
+        from modules import ui_extra_networks
+        extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
+        ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+
+        extra_tabs.__exit__()
+
+    scripts.scripts_current = None
 
 
     with gr.Blocks(analytics_enabled=False) as extras_interface:
     with gr.Blocks(analytics_enabled=False) as extras_interface:
         ui_postprocessing.create_ui()
         ui_postprocessing.create_ui()
 
 
     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             with gr.Column(variant='panel'):
             with gr.Column(variant='panel'):
                 image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
                 image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
 
 
@@ -1086,10 +1033,10 @@ def create_ui():
     modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
     modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
 
 
     with gr.Blocks(analytics_enabled=False) as train_interface:
     with gr.Blocks(analytics_enabled=False) as train_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
             gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
 
 
-        with gr.Row(variant="compact").style(equal_height=False):
+        with gr.Row(variant="compact", equal_height=False):
             with gr.Tabs(elem_id="train_tabs"):
             with gr.Tabs(elem_id="train_tabs"):
 
 
                 with gr.Tab(label="Create embedding", id="create_embedding"):
                 with gr.Tab(label="Create embedding", id="create_embedding"):
@@ -1109,7 +1056,7 @@ def create_ui():
                     new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                     new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                     new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                     new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                     new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
                     new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
-                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+                    new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
                     new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                     new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                     new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                     new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                     new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
                     new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
@@ -1249,12 +1196,12 @@ def create_ui():
 
 
             with gr.Column(elem_id='ti_gallery_container'):
             with gr.Column(elem_id='ti_gallery_container'):
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-                gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+                gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
                 gr.HTML(elem_id="ti_progress", value="")
                 gr.HTML(elem_id="ti_progress", value="")
                 ti_outcome = gr.HTML(elem_id="ti_error", value="")
                 ti_outcome = gr.HTML(elem_id="ti_error", value="")
 
 
         create_embedding.click(
         create_embedding.click(
-            fn=modules.textual_inversion.ui.create_embedding,
+            fn=textual_inversion_ui.create_embedding,
             inputs=[
             inputs=[
                 new_embedding_name,
                 new_embedding_name,
                 initialization_text,
                 initialization_text,
@@ -1269,7 +1216,7 @@ def create_ui():
         )
         )
 
 
         create_hypernetwork.click(
         create_hypernetwork.click(
-            fn=modules.hypernetworks.ui.create_hypernetwork,
+            fn=hypernetworks_ui.create_hypernetwork,
             inputs=[
             inputs=[
                 new_hypernetwork_name,
                 new_hypernetwork_name,
                 new_hypernetwork_sizes,
                 new_hypernetwork_sizes,
@@ -1289,7 +1236,7 @@ def create_ui():
         )
         )
 
 
         run_preprocess.click(
         run_preprocess.click(
-            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+            fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             _js="start_training_textual_inversion",
             inputs=[
             inputs=[
                 dummy_component,
                 dummy_component,
@@ -1325,7 +1272,7 @@ def create_ui():
         )
         )
 
 
         train_embedding.click(
         train_embedding.click(
-            fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+            fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             _js="start_training_textual_inversion",
             inputs=[
             inputs=[
                 dummy_component,
                 dummy_component,
@@ -1359,7 +1306,7 @@ def create_ui():
         )
         )
 
 
         train_hypernetwork.click(
         train_hypernetwork.click(
-            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+            fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             _js="start_training_textual_inversion",
             inputs=[
             inputs=[
                 dummy_component,
                 dummy_component,

+ 1 - 1
modules/ui_checkpoint_merger.py

@@ -29,7 +29,7 @@ def modelmerger(*args):
 class UiCheckpointMerger:
 class UiCheckpointMerger:
     def __init__(self):
     def __init__(self):
         with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
         with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-            with gr.Row().style(equal_height=False):
+            with gr.Row(equal_height=False):
                 with gr.Column(variant='compact'):
                 with gr.Column(variant='compact'):
                     self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
                     self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
 
 

+ 29 - 5
modules/ui_common.py

@@ -134,7 +134,7 @@ Requested path was: {f}
 
 
     with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
     with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
         with gr.Group(elem_id=f"{tabname}_gallery_container"):
         with gr.Group(elem_id=f"{tabname}_gallery_container"):
-            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
+            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
 
 
         generation_info = None
         generation_info = None
         with gr.Column():
         with gr.Column():
@@ -223,20 +223,44 @@ Requested path was: {f}
 
 
 
 
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+    label = None
+    for comp in refresh_components:
+        label = getattr(comp, 'label', None)
+        if label is not None:
+            break
+
     def refresh():
     def refresh():
         refresh_method()
         refresh_method()
         args = refreshed_args() if callable(refreshed_args) else refreshed_args
         args = refreshed_args() if callable(refreshed_args) else refreshed_args
 
 
         for k, v in args.items():
         for k, v in args.items():
-            setattr(refresh_component, k, v)
+            for comp in refresh_components:
+                setattr(comp, k, v)
 
 
-        return gr.update(**(args or {}))
+        return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {}))
 
 
-    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
     refresh_button.click(
     refresh_button.click(
         fn=refresh,
         fn=refresh,
         inputs=[],
         inputs=[],
-        outputs=[refresh_component]
+        outputs=refresh_components
     )
     )
     return refresh_button
     return refresh_button
 
 
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+    """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+    dialog.visible = False
+
+    button_show.click(
+        fn=lambda: gr.update(visible=True),
+        inputs=[],
+        outputs=[dialog],
+    ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+    if button_close:
+        button_close.click(fn=None, _js="closePopup")
+

+ 1 - 1
modules/ui_components.py

@@ -35,7 +35,7 @@ class FormColumn(FormComponent, gr.Column):
 
 
 
 
 class FormGroup(FormComponent, gr.Group):
 class FormGroup(FormComponent, gr.Group):
-    """Same as gr.Row but fits inside gradio forms"""
+    """Same as gr.Group but fits inside gradio forms"""
 
 
     def get_block_name(self):
     def get_block_name(self):
         return "group"
         return "group"

+ 15 - 11
modules/ui_extensions.py

@@ -164,7 +164,7 @@ def extension_table():
             ext_status = ext.status
             ext_status = ext.status
 
 
         style = ""
         style = ""
-        if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
+        if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
             style = STYLE_PRIMARY
             style = STYLE_PRIMARY
 
 
         version_link = ext.version
         version_link = ext.version
@@ -533,16 +533,20 @@ def create_ui():
                     apply = gr.Button(value=apply_label, variant="primary")
                     apply = gr.Button(value=apply_label, variant="primary")
                     check = gr.Button(value="Check for updates")
                     check = gr.Button(value="Check for updates")
                     extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
                     extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
-                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
-                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
+                    extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
 
 
                 html = ""
                 html = ""
-                if shared.opts.disable_all_extensions != "none":
-                    html = """
-<span style="color: var(--primary-400);">
-    "Disable all extensions" was set, change it to "none" to load all extensions again
-</span>
-                    """
+
+                if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
+                    if shared.cmd_opts.disable_all_extensions:
+                        msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
+                    elif shared.opts.disable_all_extensions != "none":
+                        msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
+                    elif shared.cmd_opts.disable_extra_extensions:
+                        msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
+                    html = f'<span style="color: var(--primary-400);">{msg}</span>'
+
                 info = gr.HTML(html)
                 info = gr.HTML(html)
                 extensions_table = gr.HTML('Loading...')
                 extensions_table = gr.HTML('Loading...')
                 ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
                 ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
@@ -565,7 +569,7 @@ def create_ui():
                 with gr.Row():
                 with gr.Row():
                     refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
                     refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
                     extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
                     extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
-                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
+                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
                     extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                     extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                     install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
                     install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
 
 
@@ -574,7 +578,7 @@ def create_ui():
                     sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
                     sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
 
 
                 with gr.Row():
                 with gr.Row():
-                    search_extensions_text = gr.Text(label="Search").style(container=False)
+                    search_extensions_text = gr.Text(label="Search", container=False)
 
 
                 install_result = gr.HTML()
                 install_result = gr.HTML()
                 available_extensions_table = gr.HTML()
                 available_extensions_table = gr.HTML()

+ 30 - 44
modules/ui_extra_networks.py

@@ -2,7 +2,7 @@ import os.path
 import urllib.parse
 import urllib.parse
 from pathlib import Path
 from pathlib import Path
 
 
-from modules import shared, ui_extra_networks_user_metadata, errors
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
 from modules.images import read_info_from_image, save_image_with_geninfo
 from modules.images import read_info_from_image, save_image_with_geninfo
 from modules.ui import up_down_symbol
 from modules.ui import up_down_symbol
 import gradio as gr
 import gradio as gr
@@ -101,16 +101,7 @@ class ExtraNetworksPage:
 
 
     def read_user_metadata(self, item):
     def read_user_metadata(self, item):
         filename = item.get("filename", None)
         filename = item.get("filename", None)
-        basename, ext = os.path.splitext(filename)
-        metadata_filename = basename + '.json'
-
-        metadata = {}
-        try:
-            if os.path.isfile(metadata_filename):
-                with open(metadata_filename, "r", encoding="utf8") as file:
-                    metadata = json.load(file)
-        except Exception as e:
-            errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+        metadata = extra_networks.get_user_metadata(filename)
 
 
         desc = metadata.get("description", None)
         desc = metadata.get("description", None)
         if desc is not None:
         if desc is not None:
@@ -164,7 +155,7 @@ class ExtraNetworksPage:
             subdirs = {"": 1, **subdirs}
             subdirs = {"": 1, **subdirs}
 
 
         subdirs_html = "".join([f"""
         subdirs_html = "".join([f"""
-<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
+<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
 {html.escape(subdir if subdir!="" else "all")}
 {html.escape(subdir if subdir!="" else "all")}
 </button>
 </button>
 """ for subdir in subdirs])
 """ for subdir in subdirs])
@@ -356,7 +347,7 @@ def pages_in_preferred_order(pages):
     return sorted(pages, key=lambda x: tab_scores[x.name])
     return sorted(pages, key=lambda x: tab_scores[x.name])
 
 
 
 
-def create_ui(container, button, tabname):
+def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
     ui = ExtraNetworksUi()
     ui = ExtraNetworksUi()
     ui.pages = []
     ui.pages = []
     ui.pages_contents = []
     ui.pages_contents = []
@@ -364,48 +355,42 @@ def create_ui(container, button, tabname):
     ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
     ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
     ui.tabname = tabname
     ui.tabname = tabname
 
 
-    with gr.Tabs(elem_id=tabname+"_extra_tabs"):
-        for page in ui.stored_extra_pages:
-            with gr.Tab(page.title, id=page.id_page):
-                elem_id = f"{tabname}_{page.id_page}_cards_html"
-                page_elem = gr.HTML('Loading...', elem_id=elem_id)
-                ui.pages.append(page_elem)
+    related_tabs = []
+
+    for page in ui.stored_extra_pages:
+        with gr.Tab(page.title, id=page.id_page) as tab:
+            elem_id = f"{tabname}_{page.id_page}_cards_html"
+            page_elem = gr.HTML('Loading...', elem_id=elem_id)
+            ui.pages.append(page_elem)
 
 
-                page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
+            page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
 
 
-                editor = page.create_user_metadata_editor(ui, tabname)
-                editor.create_ui()
-                ui.user_metadata_editors.append(editor)
+            editor = page.create_user_metadata_editor(ui, tabname)
+            editor.create_ui()
+            ui.user_metadata_editors.append(editor)
 
 
-    gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
-    gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
-    ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder")
-    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+            related_tabs.append(tab)
+
+    edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
+    dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
+    button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
+    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
+    checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
 
 
     ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
     ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
     ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
     ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
 
 
-    def toggle_visibility(is_visible):
-        is_visible = not is_visible
-
-        return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+    for tab in unrelated_tabs:
+        tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
 
 
-    def fill_tabs(is_empty):
-        """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
+    for tab in related_tabs:
+        tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
 
 
+    def pages_html():
         if not ui.pages_contents:
         if not ui.pages_contents:
-            refresh()
+            return refresh()
 
 
-        if is_empty:
-            return True, *ui.pages_contents
-
-        return True, *[gr.update() for _ in ui.pages_contents]
-
-    state_visible = gr.State(value=False)
-    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
-
-    state_empty = gr.State(value=True)
-    button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
+        return ui.pages_contents
 
 
     def refresh():
     def refresh():
         for pg in ui.stored_extra_pages:
         for pg in ui.stored_extra_pages:
@@ -415,6 +400,7 @@ def create_ui(container, button, tabname):
 
 
         return ui.pages_contents
         return ui.pages_contents
 
 
+    interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
     button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
     button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
 
 
     return ui
     return ui

+ 4 - 1
modules/ui_extra_networks_checkpoints.py

@@ -3,6 +3,7 @@ import os
 
 
 from modules import shared, ui_extra_networks, sd_models
 from modules import shared, ui_extra_networks, sd_models
 from modules.ui_extra_networks import quote_js
 from modules.ui_extra_networks import quote_js
+from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
 
 
 
 
 class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
 class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
     def refresh(self):
         shared.refresh_checkpoints()
         shared.refresh_checkpoints()
 
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
         path, ext = os.path.splitext(checkpoint.filename)
         path, ext = os.path.splitext(checkpoint.filename)
         return {
         return {
@@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
     def allowed_directories_for_previews(self):
     def allowed_directories_for_previews(self):
         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
 
 
+    def create_user_metadata_editor(self, ui, tabname):
+        return CheckpointUserMetadataEditor(ui, tabname, self)

+ 66 - 0
modules/ui_extra_networks_checkpoints_user_metadata.py

@@ -0,0 +1,66 @@
+import gradio as gr
+
+from modules import ui_extra_networks_user_metadata, sd_vae, shared
+from modules.ui_common import create_refresh_button
+
+
+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+    def __init__(self, ui, tabname, page):
+        super().__init__(ui, tabname, page)
+
+        self.select_vae = None
+
+    def save_user_metadata(self, name, desc, notes, vae):
+        user_metadata = self.get_user_metadata(name)
+        user_metadata["description"] = desc
+        user_metadata["notes"] = notes
+        user_metadata["vae"] = vae
+
+        self.write_user_metadata(name, user_metadata)
+
+    def update_vae(self, name):
+        if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
+            sd_vae.reload_vae_weights()
+
+    def put_values_into_components(self, name):
+        user_metadata = self.get_user_metadata(name)
+        values = super().put_values_into_components(name)
+
+        return [
+            *values[0:5],
+            user_metadata.get('vae', ''),
+        ]
+
+    def create_editor(self):
+        self.create_default_editor_elems()
+
+        with gr.Row():
+            self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
+            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
+
+        self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+        self.create_default_buttons()
+
+        viewed_components = [
+            self.edit_name,
+            self.edit_description,
+            self.html_filedata,
+            self.html_preview,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.button_edit\
+            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+        edited_components = [
+            self.edit_description,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
+        self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
+

+ 1 - 1
modules/ui_extra_networks_hypernets.py

@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
     def refresh(self):
         shared.reload_hypernetworks()
         shared.reload_hypernetworks()
 
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
         path, ext = os.path.splitext(full_path)
 
 

+ 1 - 1
modules/ui_extra_networks_textual_inversion.py

@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
     def refresh(self):
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
 
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
 
 
         path, ext = os.path.splitext(embedding.filename)
         path, ext = os.path.splitext(embedding.filename)

+ 1 - 1
modules/ui_postprocessing.py

@@ -6,7 +6,7 @@ import modules.generation_parameters_copypaste as parameters_copypaste
 def create_ui():
 def create_ui():
     tab_index = gr.State(value=0)
     tab_index = gr.State(value=0)
 
 
-    with gr.Row().style(equal_height=False, variant='compact'):
+    with gr.Row(equal_height=False, variant='compact'):
         with gr.Column(variant='compact'):
         with gr.Column(variant='compact'):
             with gr.Tabs(elem_id="mode_extras"):
             with gr.Tabs(elem_id="mode_extras"):
                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:

+ 110 - 0
modules/ui_prompt_styles.py

@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️
+styles_materialize_symbol = '\U0001f4cb'  # 📋
+
+
+def select_style(name):
+    style = shared.prompt_styles.styles.get(name)
+    existing = style is not None
+    empty = not name
+
+    prompt = style.prompt if style else gr.update()
+    negative_prompt = style.negative_prompt if style else gr.update()
+
+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+    if not name:
+        return gr.update(visible=False)
+
+    style = styles.PromptStyle(name, prompt, negative_prompt)
+    shared.prompt_styles.styles[style.name] = style
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return gr.update(visible=True)
+
+
+def delete_style(name):
+    if name == "":
+        return
+
+    shared.prompt_styles.styles.pop(name, None)
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+        self.tabname = tabname
+
+        with gr.Row(elem_id=f"{tabname}_styles_row"):
+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+            with gr.Row():
+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+
+            with gr.Row():
+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+            with gr.Row():
+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+            with gr.Row():
+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+        self.selection.change(
+            fn=select_style,
+            inputs=[self.selection],
+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+            show_progress=False,
+        )
+
+        self.save.click(
+            fn=save_style,
+            inputs=[self.selection, self.prompt, self.neg_prompt],
+            outputs=[self.delete],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.delete.click(
+            fn=delete_style,
+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+            inputs=[self.selection],
+            outputs=[self.selection, self.prompt, self.neg_prompt],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.materialize.click(
+            fn=materialize_styles,
+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            show_progress=False,
+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+
+
+
+

+ 1 - 1
modules/ui_settings.py

@@ -158,7 +158,7 @@ class UiSettings:
                     loadsave.create_ui()
                     loadsave.create_ui()
 
 
                 with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
                 with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
-                    gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+                    gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
 
 
                     with gr.Row():
                     with gr.Row():
                         with gr.Column(scale=1):
                         with gr.Column(scale=1):

+ 2 - 2
requirements.txt

@@ -7,7 +7,7 @@ blendmodes
 clean-fid
 clean-fid
 einops
 einops
 gfpgan
 gfpgan
-gradio==3.32.0
+gradio==3.39.0
 inflection
 inflection
 jsonmerge
 jsonmerge
 kornia
 kornia
@@ -30,4 +30,4 @@ tomesd
 torch
 torch
 torchdiffeq
 torchdiffeq
 torchsde
 torchsde
-transformers==4.25.1
+transformers==4.30.2

+ 7 - 7
requirements_versions.txt

@@ -1,13 +1,13 @@
-GitPython==3.1.30
+GitPython==3.1.32
 Pillow==9.5.0
 Pillow==9.5.0
-accelerate==0.18.0
+accelerate==0.21.0
 basicsr==1.4.2
 basicsr==1.4.2
 blendmodes==2022
 blendmodes==2022
 clean-fid==0.1.35
 clean-fid==0.1.35
 einops==0.4.1
 einops==0.4.1
 fastapi==0.94.0
 fastapi==0.94.0
 gfpgan==1.3.8
 gfpgan==1.3.8
-gradio==3.32.0
+gradio==3.39.0
 httpcore==0.15
 httpcore==0.15
 inflection==0.5.1
 inflection==0.5.1
 jsonmerge==1.8.0
 jsonmerge==1.8.0
@@ -22,10 +22,10 @@ pytorch_lightning==1.9.4
 realesrgan==0.3.0
 realesrgan==0.3.0
 resize-right==0.0.2
 resize-right==0.0.2
 safetensors==0.3.1
 safetensors==0.3.1
-scikit-image==0.20.0
-timm==0.6.7
-tomesd==0.1.2
+scikit-image==0.21.0
+timm==0.9.2
+tomesd==0.1.3
 torch
 torch
 torchdiffeq==0.2.3
 torchdiffeq==0.2.3
 torchsde==0.2.5
 torchsde==0.2.5
-transformers==4.25.1
+transformers==4.30.2

+ 3 - 10
scripts/xyz_grid.py

@@ -67,14 +67,6 @@ def apply_order(p, x, xs):
     p.prompt = prompt_tmp + p.prompt
     p.prompt = prompt_tmp + p.prompt
 
 
 
 
-def apply_sampler(p, x, xs):
-    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
-    if sampler_name is None:
-        raise RuntimeError(f"Unknown sampler: {x}")
-
-    p.sampler_name = sampler_name
-
-
 def confirm_samplers(p, xs):
 def confirm_samplers(p, xs):
     for x in xs:
     for x in xs:
         if x.lower() not in sd_samplers.samplers_map:
         if x.lower() not in sd_samplers.samplers_map:
@@ -224,8 +216,9 @@ axis_options = [
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
-    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
-    AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
     AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),

+ 45 - 6
style.css

@@ -8,6 +8,7 @@
     --checkbox-label-gap: 0.25em 0.1em;
     --checkbox-label-gap: 0.25em 0.1em;
     --section-header-text-size: 12pt;
     --section-header-text-size: 12pt;
     --block-background-fill: transparent;
     --block-background-fill: transparent;
+
 }
 }
 
 
 .block.padded:not(.gradio-accordion) {
 .block.padded:not(.gradio-accordion) {
@@ -42,7 +43,8 @@ div.form{
 .block.gradio-radio,
 .block.gradio-radio,
 .block.gradio-checkboxgroup,
 .block.gradio-checkboxgroup,
 .block.gradio-number,
 .block.gradio-number,
-.block.gradio-colorpicker
+.block.gradio-colorpicker,
+div.gradio-group
 {
 {
     border-width: 0 !important;
     border-width: 0 !important;
     box-shadow: none !important;
     box-shadow: none !important;
@@ -133,6 +135,15 @@ a{
     cursor: pointer;
     cursor: pointer;
 }
 }
 
 
+div.styler{
+    border: none;
+    background: var(--background-fill-primary);
+}
+
+.block.gradio-textbox{
+    overflow: visible !important;
+}
+
 
 
 /* general styled components */
 /* general styled components */
 
 
@@ -164,7 +175,7 @@ a{
 .checkboxes-row > div{
 .checkboxes-row > div{
     flex: 0;
     flex: 0;
     white-space: nowrap;
     white-space: nowrap;
-    min-width: auto;
+    min-width: auto !important;
 }
 }
 
 
 button.custom-button{
 button.custom-button{
@@ -388,6 +399,7 @@ div#extras_scale_to_tab div.form{
 #quicksettings > div, #quicksettings > fieldset{
 #quicksettings > div, #quicksettings > fieldset{
     max-width: 24em;
     max-width: 24em;
     min-width: 24em;
     min-width: 24em;
+    width: 24em;
     padding: 0;
     padding: 0;
     border: none;
     border: none;
     box-shadow: none;
     box-shadow: none;
@@ -482,6 +494,13 @@ table.popup-table .link{
     font-size: 18pt;
     font-size: 18pt;
 }
 }
 
 
+#settings .settings-info{
+    max-width: 48em;
+    border: 1px dotted #777;
+    margin: 0;
+    padding: 1em;
+}
+
 
 
 /* live preview */
 /* live preview */
 .progressDiv{
 .progressDiv{
@@ -767,9 +786,14 @@ footer {
 /* extra networks UI */
 /* extra networks UI */
 
 
 .extra-network-cards{
 .extra-network-cards{
-    height: 725px;
-    overflow: scroll;
+    height: calc(100vh - 24rem);
+    overflow: clip scroll;
     resize: vertical;
     resize: vertical;
+    min-height: 52rem;
+}
+
+.extra-networks > div.tab-nav{
+    min-height: 3.4rem;
 }
 }
 
 
 .extra-networks > div > [id *= '_extra_']{
 .extra-networks > div > [id *= '_extra_']{
@@ -784,10 +808,12 @@ footer {
     margin: 0 0.15em;
     margin: 0 0.15em;
 }
 }
 .extra-networks .tab-nav .search,
 .extra-networks .tab-nav .search,
-.extra-networks .tab-nav .sort{
-    display: inline-block;
+.extra-networks .tab-nav .sort,
+.extra-networks .tab-nav .show-dirs
+{
     margin: 0.3em;
     margin: 0.3em;
     align-self: center;
     align-self: center;
+    width: auto;
 }
 }
 
 
 .extra-networks .tab-nav .search {
 .extra-networks .tab-nav .search {
@@ -972,3 +998,16 @@ div.block.gradio-box.edit-user-metadata {
 .edit-user-metadata-buttons{
 .edit-user-metadata-buttons{
     margin-top: 1.5em;
     margin-top: 1.5em;
 }
 }
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+    width: 56em;
+    background: var(--body-background-fill);
+    padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+    margin-top: 1em;
+}

+ 10 - 36
webui.py

@@ -14,7 +14,6 @@ from typing import Iterable
 from fastapi import FastAPI
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
 
 
 import logging
 import logging
 
 
@@ -50,6 +49,7 @@ startup_timer.record("setup paths")
 import ldm.modules.encoders.modules  # noqa: F401
 import ldm.modules.encoders.modules  # noqa: F401
 startup_timer.record("import ldm")
 startup_timer.record("import ldm")
 
 
+
 from modules import extra_networks
 from modules import extra_networks
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401
 
 
@@ -58,10 +58,15 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
     torch.__long_version__ = torch.__version__
     torch.__long_version__ = torch.__version__
     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
 
 
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared
+
+if not shared.cmd_opts.skip_version_check:
+    errors.check_versions()
+
 import modules.codeformer_model as codeformer
 import modules.codeformer_model as codeformer
-import modules.face_restoration
 import modules.gfpgan_model as gfpgan
 import modules.gfpgan_model as gfpgan
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
 import modules.img2img
 import modules.img2img
 
 
 import modules.lowvram
 import modules.lowvram
@@ -130,37 +135,6 @@ def fix_asyncio_event_loop_policy():
     asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
     asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
 
 
 
 
-def check_versions():
-    if shared.cmd_opts.skip_version_check:
-        return
-
-    expected_torch_version = "2.0.0"
-
-    if version.parse(torch.__version__) < version.parse(expected_torch_version):
-        errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, as well as
-there are reports of issues with training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
-        """.strip())
-
-    expected_xformers_version = "0.0.20"
-    if shared.xformers_available:
-        import xformers
-
-        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
-            errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
-            """.strip())
-
-
 def restore_config_state_file():
 def restore_config_state_file():
     config_state_file = shared.opts.restore_config_state_file
     config_state_file = shared.opts.restore_config_state_file
     if config_state_file == "":
     if config_state_file == "":
@@ -237,7 +211,7 @@ def configure_sigint_handler():
 def configure_opts_onchange():
 def configure_opts_onchange():
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
@@ -248,7 +222,6 @@ def initialize():
     fix_asyncio_event_loop_policy()
     fix_asyncio_event_loop_policy()
     validate_tls_options()
     validate_tls_options()
     configure_sigint_handler()
     configure_sigint_handler()
-    check_versions()
     modelloader.cleanup_models()
     modelloader.cleanup_models()
     configure_opts_onchange()
     configure_opts_onchange()
 
 
@@ -368,6 +341,7 @@ def api_only():
     setup_middleware(app)
     setup_middleware(app)
     api = create_api(app)
     api = create_api(app)
 
 
+    modules.script_callbacks.before_ui_callback()
     modules.script_callbacks.app_started_callback(None, app)
     modules.script_callbacks.app_started_callback(None, app)
 
 
     print(f"Startup time: {startup_timer.summary()}.")
     print(f"Startup time: {startup_timer.summary()}.")