Browse Source

Merge branch 'dev' into fix-Hypertile-xyz

AUTOMATIC1111 1 năm trước cách đây
mục cha
commit
96f907ee09

+ 5 - 0
CHANGELOG.md

@@ -1,3 +1,8 @@
+## 1.9.4
+
+### Bug Fixes:
+*  pin setuptools version to fix the startup error ([#15882](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15882)) 
+
 ## 1.9.3
 
 ### Bug Fixes:

+ 30 - 10
extensions-builtin/Lora/networks.py

@@ -260,6 +260,16 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
 
     loaded_networks.clear()
 
+    unavailable_networks = []
+    for name in names:
+        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:
+            unavailable_networks.append(name)
+        elif available_network_aliases.get(name) is None:
+            unavailable_networks.append(name)
+
+    if unavailable_networks:
+        update_available_networks_by_names(unavailable_networks)
+
     networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
     if any(x is None for x in networks_on_disk):
         list_available_networks()
@@ -566,22 +576,16 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
     return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
 
 
-def list_available_networks():
-    available_networks.clear()
-    available_network_aliases.clear()
-    forbidden_network_aliases.clear()
-    available_network_hash_lookup.clear()
-    forbidden_network_aliases.update({"none": 1, "Addams": 1})
-
-    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
-
+def process_network_files(names: list[str] | None = None):
     candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
     candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
     for filename in candidates:
         if os.path.isdir(filename):
             continue
-
         name = os.path.splitext(os.path.basename(filename))[0]
+        # if names is provided, only load networks with names in the list
+        if names and name not in names:
+            continue
         try:
             entry = network.NetworkOnDisk(name, filename)
         except OSError:  # should catch FileNotFoundError and PermissionError etc.
@@ -597,6 +601,22 @@ def list_available_networks():
         available_network_aliases[entry.alias] = entry
 
 
+def update_available_networks_by_names(names: list[str]):
+    process_network_files(names)
+
+
+def list_available_networks():
+    available_networks.clear()
+    available_network_aliases.clear()
+    forbidden_network_aliases.clear()
+    available_network_hash_lookup.clear()
+    forbidden_network_aliases.update({"none": 1, "Addams": 1})
+
+    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
+
+    process_network_files()
+
+
 re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
 
 

+ 1 - 1
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -60,7 +60,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         else:
             sd_version = lora_on_disk.sd_version
 
-        if shared.opts.lora_show_all or not enable_filter:
+        if shared.opts.lora_show_all or not enable_filter or not shared.sd_model:
             pass
         elif sd_version == network.SdVersion.Unknown:
             model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1

+ 1 - 1
javascript/ui.js

@@ -337,8 +337,8 @@ onOptionsChanged(function() {
 let txt2img_textarea, img2img_textarea = undefined;
 
 function restart_reload() {
+    document.body.style.backgroundColor = "var(--background-fill-primary)";
     document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
-
     var requestPing = function() {
         requestGet("./internal/ping", {}, function(data) {
             location.reload();

+ 10 - 2
modules/api/api.py

@@ -438,15 +438,19 @@ class Api:
         self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
 
         selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
+        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)
 
         populate = txt2imgreq.copy(update={  # Override __init__ params
-            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
+            "sampler_name": validate_sampler_name(sampler),
             "do_not_save_samples": not txt2imgreq.save_images,
             "do_not_save_grid": not txt2imgreq.save_images,
         })
         if populate.sampler_name:
             populate.sampler_index = None  # prevent a warning later on
 
+        if not populate.scheduler and scheduler != "Automatic":
+            populate.scheduler = scheduler
+
         args = vars(populate)
         args.pop('script_name', None)
         args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
@@ -502,9 +506,10 @@ class Api:
         self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
 
         selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
+        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)
 
         populate = img2imgreq.copy(update={  # Override __init__ params
-            "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
+            "sampler_name": validate_sampler_name(sampler),
             "do_not_save_samples": not img2imgreq.save_images,
             "do_not_save_grid": not img2imgreq.save_images,
             "mask": mask,
@@ -512,6 +517,9 @@ class Api:
         if populate.sampler_name:
             populate.sampler_index = None  # prevent a warning later on
 
+        if not populate.scheduler and scheduler != "Automatic":
+            populate.scheduler = scheduler
+
         args = vars(populate)
         args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
         args.pop('script_name', None)

+ 1 - 1
modules/cmd_args.py

@@ -41,7 +41,7 @@ parser.add_argument("--lowvram", action='store_true', help="enable stable diffus
 parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
 parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
-parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
+parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
 parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)

+ 24 - 0
modules/devices.py

@@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
 fp8: bool = False
+# Force fp16 for all models in inference. No casting during inference.
+# This flag is controlled by "--precision half" command line arg.
+force_fp16: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -127,6 +130,8 @@ unet_needs_upcast = False
 
 
 def cond_cast_unet(input):
+    if force_fp16:
+        return input.to(torch.float16)
     return input.to(dtype_unet) if unet_needs_upcast else input
 
 
@@ -206,6 +211,11 @@ def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
 
+    if force_fp16:
+        # No casting during inference if force_fp16 is enabled.
+        # All tensor dtype conversion happens before inference.
+        return contextlib.nullcontext()
+
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
@@ -269,3 +279,17 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
+
+
+def force_model_fp16():
+    """
+    ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which
+    force conversion of input to float32. If force_fp16 is enabled, we need to
+    prevent this casting.
+    """
+    assert force_fp16
+    import sgm.modules.diffusionmodules.util as sgm_util
+    import ldm.modules.diffusionmodules.util as ldm_util
+    sgm_util.GroupNorm32 = torch.nn.GroupNorm
+    ldm_util.GroupNorm32 = torch.nn.GroupNorm
+    print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")

+ 1 - 1
modules/images.py

@@ -653,7 +653,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
     if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
         print('Image dimensions too large; saving as PNG')
-        extension = ".png"
+        extension = "png"
 
     if save_to_dirs is None:
         save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

+ 12 - 7
modules/processing.py

@@ -238,11 +238,6 @@ class StableDiffusionProcessing:
             self.styles = []
 
         self.sampler_noise_scheduler_override = None
-        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
-        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
-        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
-        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
-        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
 
         self.extra_generation_params = self.extra_generation_params or {}
         self.override_settings = self.override_settings or {}
@@ -259,6 +254,13 @@ class StableDiffusionProcessing:
         self.cached_uc = StableDiffusionProcessing.cached_uc
         self.cached_c = StableDiffusionProcessing.cached_c
 
+    def fill_fields_from_opts(self):
+        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
+        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
+        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
+        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
+        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
+
     @property
     def sd_model(self):
         return shared.sd_model
@@ -569,7 +571,7 @@ class Processed:
         self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
         self.all_seeds = all_seeds or p.all_seeds or [self.seed]
         self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
-        self.infotexts = infotexts or [info]
+        self.infotexts = infotexts or [info] * len(images_list)
         self.version = program_version()
 
     def js(self):
@@ -794,7 +796,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
         "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
-        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         "Tiling": "True" if p.tiling else None,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
@@ -842,6 +843,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
         sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
 
+        # backwards compatibility, fix sampler and scheduler if invalid
+        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
+
         res = process_images_inner(p)
 
     finally:
@@ -890,6 +894,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     modules.sd_hijack.model_hijack.apply_circular(p.tiling)
     modules.sd_hijack.model_hijack.clear_comments()
 
+    p.fill_fields_from_opts()
     p.setup_prompts()
 
     if isinstance(seed, list):

+ 4 - 2
modules/sd_hijack_optimizations.py

@@ -486,7 +486,8 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
-    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
+    q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in))
+
     del q_in, k_in, v_in
 
     dtype = q.dtype
@@ -497,7 +498,8 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
 
     out = out.to(dtype)
 
-    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+    b, n, h, d = out.shape
+    out = out.reshape(b, n, h * d)
     return self.to_out(out)
 
 

+ 74 - 5
modules/sd_hijack_unet.py

@@ -1,5 +1,7 @@
 import torch
 from packaging import version
+from einops import repeat
+import math
 
 from modules import devices
 from modules.sd_hijack_utils import CondFunc
@@ -36,7 +38,7 @@ th = TorchHijackForUnet()
 
 # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
 def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
-
+    """Always make sure inputs to unet are in correct dtype."""
     if isinstance(cond, dict):
         for y in cond.keys():
             if isinstance(cond[y], list):
@@ -45,7 +47,59 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
                 cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
 
     with devices.autocast():
-        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
+        if devices.unet_needs_upcast:
+            return result.float()
+        else:
+            return result
+
+
+# Monkey patch to create timestep embed tensor on device, avoiding a block.
+def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
+        )
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, 'b -> b d', d=dim)
+    return embedding
+
+
+# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
+# Prevents a lot of unnecessary aten::copy_ calls
+def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
+    # note: if no context is given, cross-attention defaults to self-attention
+    if not isinstance(context, list):
+        context = [context]
+    b, c, h, w = x.shape
+    x_in = x
+    x = self.norm(x)
+    if not self.use_linear:
+        x = self.proj_in(x)
+    x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+    if self.use_linear:
+        x = self.proj_in(x)
+    for i, block in enumerate(self.transformer_blocks):
+        x = block(x, context=context[i])
+    if self.use_linear:
+        x = self.proj_out(x)
+    x = x.view(b, h, w, c).permute(0, 3, 1, 2)
+    if not self.use_linear:
+        x = self.proj_out(x)
+    return x + x_in
 
 
 class GELUHijack(torch.nn.GELU, torch.nn.Module):
@@ -64,12 +118,15 @@ def hijack_ddpm_edit():
     if not ddpm_edit_hijack:
         CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
         CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
-        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
 
 
 unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
+CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
 CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
+
 if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
     CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
     CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
@@ -81,5 +138,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
 
-CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
-CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
+
+
+def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
+    if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
+        dtype = torch.float32
+    else:
+        dtype = devices.dtype_unet
+    return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
+
+
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)

+ 15 - 11
modules/sd_hijack_utils.py

@@ -1,7 +1,11 @@
 import importlib
 
+
+always_true_func = lambda *args, **kwargs: True
+
+
 class CondFunc:
-    def __new__(cls, orig_func, sub_func, cond_func):
+    def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
         self = super(CondFunc, cls).__new__(cls)
         if isinstance(orig_func, str):
             func_path = orig_func.split('.')
@@ -20,13 +24,13 @@ class CondFunc:
                 print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
                 pass
         self.__init__(orig_func, sub_func, cond_func)
-        return lambda *args, **kwargs: self(*args, **kwargs)
-    def __init__(self, orig_func, sub_func, cond_func):
-        self.__orig_func = orig_func
-        self.__sub_func = sub_func
-        self.__cond_func = cond_func
-    def __call__(self, *args, **kwargs):
-        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
-            return self.__sub_func(self.__orig_func, *args, **kwargs)
-        else:
-            return self.__orig_func(*args, **kwargs)
+        return lambda *args, **kwargs: self(*args, **kwargs)
+    def __init__(self, orig_func, sub_func, cond_func):
+        self.__orig_func = orig_func
+        self.__sub_func = sub_func
+        self.__cond_func = cond_func
+    def __call__(self, *args, **kwargs):
+        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+            return self.__sub_func(self.__orig_func, *args, **kwargs)
+        else:
+            return self.__orig_func(*args, **kwargs)

+ 7 - 5
modules/sd_models.py

@@ -403,6 +403,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         model.float()
         model.alphas_cumprod_original = model.alphas_cumprod
         devices.dtype_unet = torch.float32
+        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
         timer.record("apply float()")
     else:
         vae = model.first_stage_model
@@ -540,7 +541,7 @@ def repair_config(sd_config):
     if hasattr(sd_config.model.params, 'unet_config'):
         if shared.cmd_opts.no_half:
             sd_config.model.params.unet_config.params.use_fp16 = False
-        elif shared.cmd_opts.upcast_sampling:
+        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
             sd_config.model.params.unet_config.params.use_fp16 = True
 
     if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
@@ -659,10 +660,11 @@ def get_empty_cond(sd_model):
 
 
 def send_model_to_cpu(m):
-    if m.lowvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        m.to(devices.cpu)
+    if m is not None:
+        if m.lowvram:
+            lowvram.send_everything_to_cpu()
+        else:
+            m.to(devices.cpu)
 
     devices.torch_gc()
 

+ 8 - 1
modules/sd_samplers.py

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import functools
-
+import logging
 from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
 
 # imports for functions that previously were here and are used by other modules
@@ -122,4 +122,11 @@ def get_sampler_and_scheduler(sampler_name, scheduler_name):
     return sampler.name, found_scheduler.label
 
 
+def fix_p_invalid_sampler_and_scheduler(p):
+    i_sampler_name, i_scheduler = p.sampler_name, p.scheduler
+    p.sampler_name, p.scheduler = get_sampler_and_scheduler(p.sampler_name, p.scheduler)
+    if p.sampler_name != i_sampler_name or i_scheduler != p.scheduler:
+        logging.warning(f'Sampler Scheduler autocorrection: "{i_sampler_name}" -> "{p.sampler_name}", "{i_scheduler}" -> "{p.scheduler}"')
+
+
 set_samplers()

+ 9 - 2
modules/sd_samplers_cfg_denoiser.py

@@ -212,9 +212,16 @@ class CFGDenoiser(torch.nn.Module):
         uncond = denoiser_params.text_uncond
         skip_uncond = False
 
-        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
-        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+        if shared.opts.skip_early_cond != 0. and self.step / self.total_steps <= shared.opts.skip_early_cond:
             skip_uncond = True
+            self.p.extra_generation_params["Skip Early CFG"] = shared.opts.skip_early_cond
+        elif (self.step % 2 or shared.opts.s_min_uncond_all) and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+            skip_uncond = True
+            self.p.extra_generation_params["NGMS"] = s_min_uncond
+            if shared.opts.s_min_uncond_all:
+                self.p.extra_generation_params["NGMS all steps"] = shared.opts.s_min_uncond_all
+
+        if skip_uncond:
             x_in = x_in[:-batch_size]
             sigma_in = sigma_in[:-batch_size]
 

+ 40 - 0
modules/sd_schedulers.py

@@ -4,6 +4,9 @@ import torch
 
 import k_diffusion
 
+import numpy as np
+
+from modules import shared
 
 @dataclasses.dataclass
 class Scheduler:
@@ -30,6 +33,41 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
     sigs += [0.0]
     return torch.FloatTensor(sigs).to(device)
 
+def get_align_your_steps_sigmas(n, sigma_min, sigma_max, device='cpu'):
+    # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
+    def loglinear_interp(t_steps, num_steps):
+        """
+        Performs log-linear interpolation of a given array of decreasing numbers.
+        """
+        xs = np.linspace(0, 1, len(t_steps))
+        ys = np.log(t_steps[::-1])
+
+        new_xs = np.linspace(0, 1, num_steps)
+        new_ys = np.interp(new_xs, xs, ys)
+
+        interped_ys = np.exp(new_ys)[::-1].copy()
+        return interped_ys
+
+    if shared.sd_model.is_sdxl:
+        sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
+    else:
+        # Default to SD 1.5 sigmas.
+        sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
+
+    if n != len(sigmas):
+        sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
+    else:
+        sigmas.append(0.0)
+
+    return torch.FloatTensor(sigmas).to(device)
+
+def kl_optimal(n, sigma_min, sigma_max, device):
+    alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
+    alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
+    step_indices = torch.arange(n + 1, device=device)
+    sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
+    return sigmas
+
 
 schedulers = [
     Scheduler('automatic', 'Automatic', None),
@@ -38,6 +76,8 @@ schedulers = [
     Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
     Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
     Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
+    Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
+    Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),
 ]
 
 schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}

+ 8 - 0
modules/shared_init.py

@@ -31,6 +31,14 @@ def initialize():
     devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
     devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
 
+    if cmd_opts.precision == "half":
+        msg = "--no-half and --no-half-vae conflict with --precision half"
+        assert devices.dtype == torch.float16, msg
+        assert devices.dtype_vae == torch.float16, msg
+        assert devices.dtype_inference == torch.float16, msg
+        devices.force_fp16 = True
+        devices.force_model_fp16()
+
     shared.device = devices.device
     shared.weight_load_location = None if cmd_opts.lowram else "cpu"
 

+ 4 - 2
modules/shared_options.py

@@ -209,7 +209,8 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
 
 options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
-    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}, infotext='NGMS').link("PR", "https://github.com/AUTOMATIC1111/stablediffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+    "s_min_uncond_all": OptionInfo(False, "Negative Guidance minimum sigma all steps", infotext='NGMS all steps').info("By default, NGMS above skips every other step; this makes it skip all steps"),
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
@@ -380,7 +381,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
     'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
     'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
     'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
-    'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
+    'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
+    'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"),
 }))
 
 options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {

+ 1 - 0
requirements_versions.txt

@@ -1,3 +1,4 @@
+setuptools==69.5.1  # temp fix for compatibility with some old packages
 GitPython==3.1.32
 Pillow==9.5.0
 accelerate==0.21.0

+ 11 - 33
scripts/xyz_grid.py

@@ -106,17 +106,6 @@ def confirm_range(min_val, max_val, axis_label):
     return confirm_range_fun
 
 
-def apply_clip_skip(p, x, xs):
-    opts.data["CLIP_stop_at_last_layers"] = x
-
-
-def apply_upscale_latent_space(p, x, xs):
-    if x.lower().strip() != '0':
-        opts.data["use_scale_latent_for_hires_fix"] = True
-    else:
-        opts.data["use_scale_latent_for_hires_fix"] = False
-
-
 def apply_size(p, x: str, xs) -> None:
     try:
         width, _, height = x.partition('x')
@@ -129,21 +118,16 @@ def apply_size(p, x: str, xs) -> None:
 
 
 def find_vae(name: str):
-    if name.lower() in ['auto', 'automatic']:
-        return modules.sd_vae.unspecified
-    if name.lower() == 'none':
-        return None
-    else:
-        choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
-        if len(choices) == 0:
-            print(f"No VAE found for {name}; using automatic")
-            return modules.sd_vae.unspecified
-        else:
-            return modules.sd_vae.vae_dict[choices[0]]
+    match name := name.lower().strip():
+        case 'auto', 'automatic':
+            return 'Automatic'
+        case 'none':
+            return 'None'
+    return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), print(f'No VAE found for {name}; using Automatic') or 'Automatic')
 
 
 def apply_vae(p, x, xs):
-    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
+    p.override_settings['sd_vae'] = find_vae(x)
 
 
 def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@@ -151,7 +135,7 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
 
 
 def apply_uni_pc_order(p, x, xs):
-    opts.data["uni_pc_order"] = min(x, p.steps - 1)
+    p.override_settings['uni_pc_order'] = min(x, p.steps - 1)
 
 
 def apply_face_restore(p, opt, x):
@@ -277,13 +261,13 @@ axis_options = [
     AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
     AxisOption("Schedule rho", float, apply_override("rho")),
     AxisOption("Eta", float, apply_field("eta")),
-    AxisOption("Clip skip", int, apply_clip_skip),
+    AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
     AxisOption("Denoising", float, apply_field("denoising_strength")),
     AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
     AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
     AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
     AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
-    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
+    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
     AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
     AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
     AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
@@ -412,18 +396,12 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
 
 class SharedSettingsStackHelper(object):
     def __enter__(self):
-        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
-        self.vae = opts.sd_vae
-        self.uni_pc_order = opts.uni_pc_order
+        pass
 
     def __exit__(self, exc_type, exc_value, tb):
-        opts.data["sd_vae"] = self.vae
-        opts.data["uni_pc_order"] = self.uni_pc_order
         modules.sd_models.reload_model_weights()
         modules.sd_vae.reload_vae_weights()
 
-        opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
-
 
 re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
 re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")

+ 6 - 1
webui-macos-env.sh

@@ -11,7 +11,12 @@ fi
 
 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0"
 export PYTORCH_ENABLE_MPS_FALLBACK=1
 
+if [[ "$(sysctl -n machdep.cpu.brand_string)" =~ ^.*"Intel".*$ ]]; then
+    export TORCH_COMMAND="pip install torch==2.1.2 torchvision==0.16.2"
+else
+    export TORCH_COMMAND="pip install torch==2.3.0 torchvision==0.18.0"
+fi
+
 ####################################################################