Эх сурвалжийг харах

Merge branch 'dev' into master

AUTOMATIC1111 2 жил өмнө
parent
commit
b7c5b30f14
49 өөрчлөгдсөн 705 нэмэгдсэн , 390 устгасан
  1. 1 1
      .github/workflows/on_pull_request.yaml
  2. 2 2
      .github/workflows/run_tests.yaml
  3. 3 0
      README.md
  4. 3 5
      extensions-builtin/LDSR/ldsr_model_arch.py
  5. 8 12
      extensions-builtin/LDSR/scripts/ldsr_model.py
  6. 1 1
      extensions-builtin/Lora/lora.py
  7. 13 17
      extensions-builtin/ScuNET/scripts/scunet_model.py
  8. 49 34
      extensions-builtin/SwinIR/scripts/swinir_model.py
  9. 29 1
      extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
  10. 1 0
      extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
  11. 3 2
      javascript/edit-attention.js
  12. 41 0
      javascript/edit-order.js
  13. 18 0
      javascript/extensions.js
  14. 73 56
      modules/api/api.py
  15. 0 4
      modules/api/models.py
  16. 4 1
      modules/call_queue.py
  17. 1 0
      modules/cmd_args.py
  18. 1 5
      modules/codeformer_model.py
  19. 4 7
      modules/devices.py
  20. 10 13
      modules/esrgan_model.py
  21. 3 0
      modules/extra_networks.py
  22. 1 2
      modules/extras.py
  23. 0 29
      modules/generation_parameters_copypaste.py
  24. 1 1
      modules/gfpgan_model.py
  25. 4 26
      modules/hypernetworks/hypernetwork.py
  26. 48 21
      modules/images.py
  27. 52 20
      modules/img2img.py
  28. 1 2
      modules/interrogate.py
  29. 3 3
      modules/launch_utils.py
  30. 31 8
      modules/mac_specific.py
  31. 27 4
      modules/modelloader.py
  32. 0 14
      modules/paths.py
  33. 4 3
      modules/postprocessing.py
  34. 15 8
      modules/processing.py
  35. 15 18
      modules/realesrgan_model.py
  36. 39 1
      modules/scripts.py
  37. 12 7
      modules/sd_models.py
  38. 29 6
      modules/shared.py
  39. 44 4
      modules/textual_inversion/logging.py
  40. 1 1
      modules/textual_inversion/preprocess.py
  41. 4 2
      modules/textual_inversion/textual_inversion.py
  42. 10 7
      modules/txt2img.py
  43. 10 3
      modules/ui.py
  44. 23 6
      modules/ui_extensions.py
  45. 4 4
      modules/ui_extra_networks.py
  46. 14 7
      modules/ui_settings.py
  47. 15 2
      style.css
  48. 19 15
      webui.py
  49. 11 5
      webui.sh

+ 1 - 1
.github/workflows/on_pull_request.yaml

@@ -18,7 +18,7 @@ jobs:
           #     not to have GHA download an (at the time of writing) 4 GB cache
           #     of PyTorch and other dependencies.
       - name: Install Ruff
-        run: pip install ruff==0.0.265
+        run: pip install ruff==0.0.272
       - name: Run Ruff
         run: ruff .
   lint-js:

+ 2 - 2
.github/workflows/run_tests.yaml

@@ -42,7 +42,7 @@ jobs:
           --no-half
           --disable-opt-split-attention
           --use-cpu all
-          --add-stop-route
+          --api-server-stop
           2>&1 | tee output.txt &
       - name: Run tests
         run: |
@@ -50,7 +50,7 @@ jobs:
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
-        run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
+        run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
       - name: Show coverage
         run: |
           python -m coverage combine .coverage*

+ 3 - 0
README.md

@@ -135,8 +135,11 @@ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-w
 Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
 
 ## Documentation
+
 The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
 
+For the purposes of getting Google and other search engines to crawl the wiki, here's a link to the (not for humans) [crawlable wiki](https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki).
+
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 

+ 3 - 5
extensions-builtin/LDSR/ldsr_model_arch.py

@@ -12,7 +12,7 @@ import safetensors.torch
 
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
-from modules import shared, sd_hijack
+from modules import shared, sd_hijack, devices
 
 cached_ldsr_model: torch.nn.Module = None
 
@@ -112,8 +112,7 @@ class LDSR:
 
 
         gc.collect()
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
+        devices.torch_gc()
 
         im_og = image
         width_og, height_og = im_og.size
@@ -150,8 +149,7 @@ class LDSR:
 
         del model
         gc.collect()
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
+        devices.torch_gc()
 
         return a
 

+ 8 - 12
extensions-builtin/LDSR/scripts/ldsr_model.py

@@ -1,7 +1,6 @@
 import os
 
-from basicsr.utils.download_util import load_file_from_url
-
+from modules.modelloader import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks, errors
@@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
         if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
             model = local_safetensors_path
         else:
-            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
+            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
 
-        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
+        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
 
-        try:
-            return LDSR(model, yaml)
-        except Exception:
-            errors.report("Error importing LDSR", exc_info=True)
-        return None
+        return LDSR(model, yaml)
 
     def do_upscale(self, img, path):
-        ldsr = self.load_model(path)
-        if ldsr is None:
-            print("NO LDSR!")
+        try:
+            ldsr = self.load_model(path)
+        except Exception:
+            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
             return img
         ddim_steps = shared.opts.ldsr_steps
         return ldsr.super_resolution(img, ddim_steps, self.scale)

+ 1 - 1
extensions-builtin/Lora/lora.py

@@ -443,7 +443,7 @@ def list_available_loras():
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
     candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
-    for filename in sorted(candidates, key=str.lower):
+    for filename in candidates:
         if os.path.isdir(filename):
             continue
 

+ 13 - 17
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -1,4 +1,3 @@
-import os.path
 import sys
 
 import PIL.Image
@@ -6,12 +5,11 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from basicsr.utils.download_util import load_file_from_url
-
 import modules.upscaler
 from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet
 
+from modules.modelloader import load_file_from_url
 from modules.shared import opts
 
 
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         scalers = []
         add_model2 = True
         for file in model_paths:
-            if "http" in file:
+            if file.startswith("http"):
                 name = self.model_name
             else:
                 name = modelloader.friendly_name(file)
@@ -87,11 +85,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
     def do_upscale(self, img: PIL.Image.Image, selected_file):
 
-        torch.cuda.empty_cache()
+        devices.torch_gc()
 
-        model = self.load_model(selected_file)
-        if model is None:
-            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+        try:
+            model = self.load_model(selected_file)
+        except Exception as e:
+            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img
 
         device = devices.get_device_for('scunet')
@@ -111,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
         np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
         del torch_img, torch_output
-        torch.cuda.empty_cache()
+        devices.torch_gc()
 
         output = np_output.transpose((1, 2, 0))  # CHW to HWC
         output = output[:, :, ::-1]  # BGR to RGB
@@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
-        if "http" in path:
-            filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+        if path.startswith("http"):
+            # TODO: this doesn't use `path` at all?
+            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
-            print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
-            return None
-
-        model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
         model.load_state_dict(torch.load(filename), strict=True)
         model.eval()
         for _, v in model.named_parameters():

+ 49 - 34
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -1,34 +1,35 @@
-import os
+import sys
+import platform
 
 import numpy as np
 import torch
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
 from modules.shared import opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
 from modules.upscaler import Upscaler, UpscalerData
 
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
 
 device_swinir = devices.get_device_for('swinir')
 
 
 class UpscalerSwinIR(Upscaler):
     def __init__(self, dirname):
+        self._cached_model = None           # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
+        self._cached_model_config = None    # to clear '_cached_model' when changing model (v1/v2) or settings
         self.name = "SwinIR"
-        self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
-                         "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
-                         "-L_x4_GAN.pth "
+        self.model_url = SWINIR_MODEL_URL
         self.model_name = "SwinIR 4x"
         self.user_path = dirname
         super().__init__()
         scalers = []
         model_files = self.find_models(ext_filter=[".pt", ".pth"])
         for model in model_files:
-            if "http" in model:
+            if model.startswith("http"):
                 name = self.model_name
             else:
                 name = modelloader.friendly_name(model)
@@ -37,42 +38,54 @@ class UpscalerSwinIR(Upscaler):
         self.scalers = scalers
 
     def do_upscale(self, img, model_file):
-        model = self.load_model(model_file)
-        if model is None:
-            return img
-        model = model.to(device_swinir, dtype=devices.dtype)
+        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
+            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+        current_config = (model_file, opts.SWIN_tile)
+
+        if use_compile and self._cached_model_config == current_config:
+            model = self._cached_model
+        else:
+            self._cached_model = None
+            try:
+                model = self.load_model(model_file)
+            except Exception as e:
+                print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
+                return img
+            model = model.to(device_swinir, dtype=devices.dtype)
+            if use_compile:
+                model = torch.compile(model)
+                self._cached_model = model
+                self._cached_model_config = current_config
         img = upscale(img, model)
-        try:
-            torch.cuda.empty_cache()
-        except Exception:
-            pass
+        devices.torch_gc()
         return img
 
     def load_model(self, path, scale=4):
-        if "http" in path:
-            dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
-            filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+        if path.startswith("http"):
+            filename = modelloader.load_file_from_url(
+                url=path,
+                model_dir=self.model_download_path,
+                file_name=f"{self.model_name.replace(' ', '_')}.pth",
+            )
         else:
             filename = path
-        if filename is None or not os.path.exists(filename):
-            return None
         if filename.endswith(".v2.pth"):
-            model = net2(
-            upscale=scale,
-            in_chans=3,
-            img_size=64,
-            window_size=8,
-            img_range=1.0,
-            depths=[6, 6, 6, 6, 6, 6],
-            embed_dim=180,
-            num_heads=[6, 6, 6, 6, 6, 6],
-            mlp_ratio=2,
-            upsampler="nearest+conv",
-            resi_connection="1conv",
+            model = Swin2SR(
+                upscale=scale,
+                in_chans=3,
+                img_size=64,
+                window_size=8,
+                img_range=1.0,
+                depths=[6, 6, 6, 6, 6, 6],
+                embed_dim=180,
+                num_heads=[6, 6, 6, 6, 6, 6],
+                mlp_ratio=2,
+                upsampler="nearest+conv",
+                resi_connection="1conv",
             )
             params = None
         else:
-            model = net(
+            model = SwinIR(
                 upscale=scale,
                 in_chans=3,
                 img_size=64,
@@ -172,6 +185,8 @@ def on_ui_settings():
 
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":    # torch.compile() require pytorch 2.0 or above, and not on Windows
+        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
 
 
 script_callbacks.on_ui_settings(on_ui_settings)

+ 29 - 1
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js

@@ -200,7 +200,8 @@ onUiLoaded(async() => {
         canvas_hotkey_move: "KeyF",
         canvas_hotkey_overlap: "KeyO",
         canvas_disabled_functions: [],
-        canvas_show_tooltip: true
+        canvas_show_tooltip: true,
+        canvas_blur_prompt: false
     };
 
     const functionMap = {
@@ -608,6 +609,19 @@ onUiLoaded(async() => {
 
         // Handle keydown events
         function handleKeyDown(event) {
+            // Disable key locks to make pasting from the buffer work correctly
+            if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
+                return;
+            }
+
+            // before activating shortcut, ensure user is not actively typing in an input field
+            if (!hotkeysConfig.canvas_blur_prompt) {
+                if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
+                    return;
+                }
+            }
+
+
             const hotkeyActions = {
                 [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
                 [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
@@ -686,6 +700,20 @@ onUiLoaded(async() => {
 
         // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
         function handleMoveKeyDown(e) {
+
+            // Disable key locks to make pasting from the buffer work correctly
+            if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
+                return;
+            }
+
+            // before activating shortcut, ensure user is not actively typing in an input field
+            if (!hotkeysConfig.canvas_blur_prompt) {
+                if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
+                    return;
+                }
+            }
+
+
             if (e.code === hotkeysConfig.canvas_hotkey_move) {
                 if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
                     e.preventDefault();

+ 1 - 0
extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py

@@ -9,5 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
     "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
     "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
     "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
+    "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
     "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
 }))

+ 3 - 2
javascript/edit-attention.js

@@ -100,11 +100,12 @@ function keyupEditAttention(event) {
     if (String(weight).length == 1) weight += ".0";
 
     if (closeCharacter == ')' && weight == 1) {
-        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+        var endParenPos = text.substring(selectionEnd).indexOf(')');
+        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
         selectionStart--;
         selectionEnd--;
     } else {
-        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
     }
 
     target.focus();

+ 41 - 0
javascript/edit-order.js

@@ -0,0 +1,41 @@
+/* alt+left/right moves text in prompt */
+
+function keyupEditOrder(event) {
+    if (!opts.keyedit_move) return;
+
+    let target = event.originalTarget || event.composedPath()[0];
+    if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
+    if (!event.altKey) return;
+
+    let isLeft = event.key == "ArrowLeft";
+    let isRight = event.key == "ArrowRight";
+    if (!isLeft && !isRight) return;
+    event.preventDefault();
+
+    let selectionStart = target.selectionStart;
+    let selectionEnd = target.selectionEnd;
+    let text = target.value;
+    let items = text.split(",");
+    let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
+    let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
+    let range = indexEnd - indexStart + 1;
+
+    if (isLeft && indexStart > 0) {
+        items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
+        target.value = items.join();
+        target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
+        target.selectionEnd = items.slice(0, indexEnd).join().length;
+    } else if (isRight && indexEnd < items.length - 1) {
+        items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
+        target.value = items.join();
+        target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
+        target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
+    }
+
+    event.preventDefault();
+    updateInput(target);
+}
+
+addEventListener('keydown', (event) => {
+    keyupEditOrder(event);
+});

+ 18 - 0
javascript/extensions.js

@@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
     }
     return [confirmed, config_state_name, config_restore_type];
 }
+
+function toggle_all_extensions(event) {
+    gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
+        checkbox_el.checked = event.target.checked;
+    });
+}
+
+function toggle_extension() {
+    let all_extensions_toggled = true;
+    for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
+        if (!checkbox_el.checked) {
+            all_extensions_toggled = false;
+            break;
+        }
+    }
+
+    gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
+}

+ 73 - 56
modules/api/api.py

@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
 from modules.sd_vae import vae_dict
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
@@ -30,13 +30,7 @@ from modules import devices
 from typing import Dict, List, Any
 import piexif
 import piexif.helper
-
-
-def upscaler_to_index(name: str):
-    try:
-        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+from contextlib import closing
 
 
 def script_name_to_index(name, scripts):
@@ -84,6 +78,8 @@ def encode_pil_to_base64(image):
             image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
 
         elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+            if image.mode == "RGBA":
+                image = image.convert("RGB")
             parameters = image.info.get('parameters', None)
             exif_bytes = piexif.dump({
                 "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -209,6 +205,11 @@ class Api:
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
 
+        if shared.cmd_opts.api_server_stop:
+            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
 
@@ -324,19 +325,19 @@ class Api:
         args.pop('save_images', None)
 
         with self.queue_lock:
-            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
-            p.scripts = script_runner
-            p.outpath_grids = opts.outdir_txt2img_grids
-            p.outpath_samples = opts.outdir_txt2img_samples
-
-            shared.state.begin()
-            if selectable_scripts is not None:
-                p.script_args = script_args
-                processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
-            else:
-                p.script_args = tuple(script_args) # Need to pass args as tuple here
-                processed = process_images(p)
-            shared.state.end()
+            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+                p.scripts = script_runner
+                p.outpath_grids = opts.outdir_txt2img_grids
+                p.outpath_samples = opts.outdir_txt2img_samples
+
+                shared.state.begin(job="scripts_txt2img")
+                if selectable_scripts is not None:
+                    p.script_args = script_args
+                    processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+                else:
+                    p.script_args = tuple(script_args) # Need to pass args as tuple here
+                    processed = process_images(p)
+                shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
@@ -380,20 +381,20 @@ class Api:
         args.pop('save_images', None)
 
         with self.queue_lock:
-            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
-            p.init_images = [decode_base64_to_image(x) for x in init_images]
-            p.scripts = script_runner
-            p.outpath_grids = opts.outdir_img2img_grids
-            p.outpath_samples = opts.outdir_img2img_samples
-
-            shared.state.begin()
-            if selectable_scripts is not None:
-                p.script_args = script_args
-                processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
-            else:
-                p.script_args = tuple(script_args) # Need to pass args as tuple here
-                processed = process_images(p)
-            shared.state.end()
+            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+                p.init_images = [decode_base64_to_image(x) for x in init_images]
+                p.scripts = script_runner
+                p.outpath_grids = opts.outdir_img2img_grids
+                p.outpath_samples = opts.outdir_img2img_samples
+
+                shared.state.begin(job="scripts_img2img")
+                if selectable_scripts is not None:
+                    p.script_args = script_args
+                    processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+                else:
+                    p.script_args = tuple(script_args) # Need to pass args as tuple here
+                    processed = process_images(p)
+                shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
@@ -517,6 +518,10 @@ class Api:
         return options
 
     def set_config(self, req: Dict[str, Any]):
+        checkpoint_name = req.get("sd_model_checkpoint", None)
+        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+            raise RuntimeError(f"model {checkpoint_name!r} not found")
+
         for k, v in req.items():
             shared.opts.set(k, v)
 
@@ -598,44 +603,42 @@ class Api:
 
     def create_embedding(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="create_embedding")
             filename = create_embedding(**args) # create empty embedding
             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
-            shared.state.end()
             return models.CreateResponse(info=f"create embedding filename: {filename}")
         except AssertionError as e:
-            shared.state.end()
             return models.TrainResponse(info=f"create embedding error: {e}")
+        finally:
+            shared.state.end()
+
 
     def create_hypernetwork(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="create_hypernetwork")
             filename = create_hypernetwork(**args) # create empty embedding
-            shared.state.end()
             return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
         except AssertionError as e:
-            shared.state.end()
             return models.TrainResponse(info=f"create hypernetwork error: {e}")
+        finally:
+            shared.state.end()
 
     def preprocess(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="preprocess")
             preprocess(**args) # quick operation unless blip/booru interrogation is enabled
             shared.state.end()
-            return models.PreprocessResponse(info = 'preprocess complete')
+            return models.PreprocessResponse(info='preprocess complete')
         except KeyError as e:
-            shared.state.end()
             return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
-        except AssertionError as e:
-            shared.state.end()
+        except Exception as e:
             return models.PreprocessResponse(info=f"preprocess error: {e}")
-        except FileNotFoundError as e:
+        finally:
             shared.state.end()
-            return models.PreprocessResponse(info=f'preprocess error: {e}')
 
     def train_embedding(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="train_embedding")
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             filename = ''
@@ -648,15 +651,15 @@ class Api:
             finally:
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
-                shared.state.end()
             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-        except AssertionError as msg:
-            shared.state.end()
+        except Exception as msg:
             return models.TrainResponse(info=f"train embedding error: {msg}")
+        finally:
+            shared.state.end()
 
     def train_hypernetwork(self, args: dict):
         try:
-            shared.state.begin()
+            shared.state.begin(job="train_hypernetwork")
             shared.loaded_hypernetworks = []
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
@@ -674,9 +677,10 @@ class Api:
                     sd_hijack.apply_optimizations()
                 shared.state.end()
             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-        except AssertionError:
+        except Exception as exc:
+            return models.TrainResponse(info=f"train embedding error: {exc}")
+        finally:
             shared.state.end()
-            return models.TrainResponse(info=f"train embedding error: {error}")
 
     def get_memory(self):
         try:
@@ -716,3 +720,16 @@ class Api:
     def launch(self, server_name, port):
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+
+    def kill_webui(self):
+        restart.stop_program()
+
+    def restart_webui(self):
+        if restart.is_restartable():
+            restart.restart_program()
+        return Response(status_code=501)
+
+    def stop_webui(request):
+        shared.state.server_command = "stop"
+        return Response("Stopping.")
+

+ 0 - 4
modules/api/models.py

@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
     prompt: Optional[str] = Field(title="Prompt")
     negative_prompt: Optional[str] = Field(title="Negative Prompt")
 
-class ArtistItem(BaseModel):
-    name: str = Field(title="Name")
-    score: float = Field(title="Score")
-    category: str = Field(title="Category")
 
 class EmbeddingItem(BaseModel):
     step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")

+ 4 - 1
modules/call_queue.py

@@ -1,3 +1,4 @@
+from functools import wraps
 import html
 import threading
 import time
@@ -18,6 +19,7 @@ def wrap_queued_call(func):
 
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
+    @wraps(func)
     def f(*args, **kwargs):
 
         # if the first argument is a string that says "task(...)", it is treated as a job id
@@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
             id_task = None
 
         with queue_lock:
-            shared.state.begin()
+            shared.state.begin(job=id_task)
             progress.start_task(id_task)
 
             try:
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
 
 def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+    @wraps(func)
     def f(*args, extra_outputs_array=extra_outputs, **kwargs):
         run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
         if run_memmon:

+ 1 - 0
modules/cmd_args.py

@@ -107,4 +107,5 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')

+ 1 - 5
modules/codeformer_model.py

@@ -15,7 +15,6 @@ model_dir = "Codeformer"
 model_path = os.path.join(models_path, model_dir)
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 
-have_codeformer = False
 codeformer = None
 
 
@@ -100,7 +99,7 @@ def setup_model(dirname):
                             output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                             restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                         del output
-                        torch.cuda.empty_cache()
+                        devices.torch_gc()
                     except Exception:
                         errors.report('Failed inference for CodeFormer', exc_info=True)
                         restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -123,9 +122,6 @@ def setup_model(dirname):
 
                 return restored_img
 
-        global have_codeformer
-        have_codeformer = True
-
         global codeformer
         codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)

+ 4 - 7
modules/devices.py

@@ -15,13 +15,6 @@ def has_mps() -> bool:
     else:
         return mac_specific.has_mps
 
-def extract_device_id(args, name):
-    for x in range(len(args)):
-        if name in args[x]:
-            return args[x + 1]
-
-    return None
-
 
 def get_cuda_device_string():
     from modules import shared
@@ -56,11 +49,15 @@ def get_device_for(task):
 
 
 def torch_gc():
+
     if torch.cuda.is_available():
         with torch.cuda.device(get_cuda_device_string()):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
+    if has_mps():
+        mac_specific.torch_mps_gc()
+
 
 def enable_tf32():
     if torch.cuda.is_available():

+ 10 - 13
modules/esrgan_model.py

@@ -1,15 +1,13 @@
-import os
+import sys
 
 import numpy as np
 import torch
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 
 import modules.esrgan_model_arch as arch
 from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
 
 
 def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
             scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
             scalers.append(scaler_data)
         for file in model_paths:
-            if "http" in file:
+            if file.startswith("http"):
                 name = self.model_name
             else:
                 name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
             self.scalers.append(scaler_data)
 
     def do_upscale(self, img, selected_model):
-        model = self.load_model(selected_model)
-        if model is None:
+        try:
+            model = self.load_model(selected_model)
+        except Exception as e:
+            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
             return img
         model.to(devices.device_esrgan)
         img = esrgan_upscale(model, img)
         return img
 
     def load_model(self, path: str):
-        if "http" in path:
-            filename = load_file_from_url(
+        if path.startswith("http"):
+            # TODO: this doesn't use `path` at all?
+            filename = modelloader.load_file_from_url(
                 url=self.model_url,
                 model_dir=self.model_download_path,
                 file_name=f"{self.model_name}.pth",
-                progress=True,
             )
         else:
             filename = path
-        if not os.path.exists(filename) or filename is None:
-            print(f"Unable to load {self.model_path} from {filename}")
-            return None
 
         state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
 

+ 3 - 0
modules/extra_networks.py

@@ -103,6 +103,9 @@ def activate(p, extra_network_data):
         except Exception as e:
             errors.display(e, f"activating extra network {extra_network_name}")
 
+    if p.scripts is not None:
+        p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
+
 
 def deactivate(p, extra_network_data):
     """call deactivate for extra networks in extra_network_data in specified order, then call

+ 1 - 2
modules/extras.py

@@ -73,8 +73,7 @@ def to_half(tensor, enable):
 
 
 def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
-    shared.state.begin()
-    shared.state.job = 'model-merge'
+    shared.state.begin(job="model-merge")
 
     def fail(message):
         shared.state.textinfo = message

+ 0 - 29
modules/generation_parameters_copypaste.py

@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
     return img, w, h
 
 
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
-    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
-    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
-    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
-    If the infotext has no hash, then a hypernet with the same name will be selected instead.
-    """
-    hypernet_name = hypernet_name.lower()
-    if hypernet_hash is not None:
-        # Try to match the hash in the name
-        for hypernet_key in shared.hypernetworks.keys():
-            result = re_hypernet_hash.search(hypernet_key)
-            if result is not None and result[1] == hypernet_hash:
-                return hypernet_key
-    else:
-        # Fall back to a hypernet with the same name
-        for hypernet_key in shared.hypernetworks.keys():
-            if hypernet_key.lower().startswith(hypernet_name):
-                return hypernet_key
-
-    return None
-
-
 def restore_old_hires_fix_params(res):
     """for infotexts that specify old First pass size parameter, convert it into
     width, height, and hr scale"""
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     return res
 
 
-settings_map = {}
-
-
-
 infotext_to_setting_name_mapping = [
     ('Clip skip', 'CLIP_stop_at_last_layers', ),
     ('Conditional mask weight', 'inpainting_mask_weight'),

+ 1 - 1
modules/gfpgan_model.py

@@ -25,7 +25,7 @@ def gfpgann():
         return None
 
     models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
-    if len(models) == 1 and "http" in models[0]:
+    if len(models) == 1 and models[0].startswith("http"):
         model_file = models[0]
     elif len(models) != 0:
         latest_file = max(models, key=os.path.getctime)

+ 4 - 26
modules/hypernetworks/hypernetwork.py

@@ -3,6 +3,7 @@ import glob
 import html
 import os
 import inspect
+from contextlib import closing
 
 import modules.textual_inversion.dataset
 import torch
@@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None):
         shared.loaded_hypernetworks.append(hypernetwork)
 
 
-def find_closest_hypernetwork_name(search: str):
-    if not search:
-        return None
-    search = search.lower()
-    applicable = [name for name in shared.hypernetworks if search in name.lower()]
-    if not applicable:
-        return None
-    applicable = sorted(applicable, key=lambda name: len(name))
-    return applicable[0]
-
-
 def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
     hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
 
@@ -446,18 +436,6 @@ def statistics(data):
     return total_information, recent_information
 
 
-def report_statistics(loss_info:dict):
-    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
-    for key in keys:
-        try:
-            print("Loss statistics for file " + key)
-            info, recent = statistics(list(loss_info[key]))
-            print(info)
-            print(recent)
-        except Exception as e:
-            print(e)
-
-
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     # Remove illegal characters from name.
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -734,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 
                     preview_text = p.prompt
 
-                    processed = processing.process_images(p)
-                    image = processed.images[0] if len(processed.images) > 0 else None
+                    with closing(p):
+                        processed = processing.process_images(p)
+                        image = processed.images[0] if len(processed.images) > 0 else None
 
                     if unload:
                         shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
         pbar.leave = False
         pbar.close()
         hypernetwork.eval()
-        #report_statistics(loss_dict)
         sd_hijack_checkpoint.remove()
 
 

+ 48 - 21
modules/images.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import datetime
 
 import pytz
@@ -10,7 +12,7 @@ import re
 import numpy as np
 import piexif
 import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
 import string
 import json
 import hashlib
@@ -139,6 +141,11 @@ class GridAnnotation:
 
 
 def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+    color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+    color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+    color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
     def wrap(drawing, text, font, line_length):
         lines = ['']
         for word in text.split():
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
 
     fnt = get_font(fontsize)
 
-    color_active = (0, 0, 0)
-    color_inactive = (153, 153, 153)
-
     pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
 
     cols = im.width // width
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
     assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
     assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
 
-    calc_img = Image.new("RGB", (1, 1), "white")
+    calc_img = Image.new("RGB", (1, 1), color_background)
     calc_d = ImageDraw.Draw(calc_img)
 
     for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
 
     pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
 
-    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
 
     for row in range(rows):
         for col in range(cols):
@@ -302,12 +306,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
 
         if ratio < src_ratio:
             fill_height = height // 2 - src_h // 2
-            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
-            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
         elif ratio > src_ratio:
             fill_width = width // 2 - src_w // 2
-            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
-            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
 
     return res
 
@@ -372,8 +378,8 @@ class FilenameGenerator:
         'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]
         'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
         'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+        'user': lambda self: self.p.user,
         'vae_filename': lambda self: self.get_vae_filename(),
-
     }
     default_time_format = '%Y%m%d%H%M%S'
 
@@ -497,13 +503,23 @@ def get_next_sequence_number(path, basename):
     return result + 1
 
 
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+    """
+    Saves image to filename, including geninfo as text information for generation info.
+    For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+    For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+    """
+
     if extension is None:
         extension = os.path.splitext(filename)[1]
 
     image_format = Image.registered_extensions()[extension]
 
     if extension.lower() == '.png':
+        existing_pnginfo = existing_pnginfo or {}
+        if opts.enable_pnginfo:
+            existing_pnginfo[pnginfo_section_name] = geninfo
+
         if opts.enable_pnginfo:
             pnginfo_data = PngImagePlugin.PngInfo()
             for k, v in (existing_pnginfo or {}).items():
@@ -622,7 +638,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         """
         temp_file_path = f"{filename_without_extension}.tmp"
 
-        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
 
         os.replace(temp_file_path, filename_without_extension + extension)
 
@@ -639,12 +655,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
     if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
         ratio = image.width / image.height
-
+        resize_to = None
         if oversize and ratio > 1:
-            image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+            resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
         elif oversize:
-            image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+            resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
 
+        if resize_to is not None:
+            try:
+                # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+                image = image.resize(resize_to, LANCZOS)
+            except Exception:
+                image = image.resize(resize_to)
         try:
             _atomically_save_image(image, fullfn_without_extension, ".jpg")
         except Exception as e:
@@ -662,8 +684,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     return fullfn, txt_fullfn
 
 
-def read_info_from_image(image):
-    items = image.info or {}
+IGNORED_INFO_KEYS = {
+    'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+    'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+    'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+    items = (image.info or {}).copy()
 
     geninfo = items.pop('parameters', None)
 
@@ -679,9 +708,7 @@ def read_info_from_image(image):
             items['exif comment'] = exif_comment
             geninfo = exif_comment
 
-    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
-                    'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
-                    'icc_profile', 'chromaticity']:
+    for field in IGNORED_INFO_KEYS:
         items.pop(field, None)
 
     if items.get("Software", None) == "NovelAI":

+ 52 - 20
modules/img2img.py

@@ -1,23 +1,26 @@
 import os
+from contextlib import closing
 from pathlib import Path
 
 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
 
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
+from modules.images import save_image
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
 import modules.scripts
 
 
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
     processing.fix_seed(p)
 
-    images = shared.listfiles(input_dir)
+    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
 
     is_inpaint_batch = False
     if inpaint_mask_dir:
@@ -36,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
 
     state.job_count = len(images) * p.n_iter
 
+    # extract "default" params to use in case getting png info fails
+    prompt = p.prompt
+    negative_prompt = p.negative_prompt
+    seed = p.seed
+    cfg_scale = p.cfg_scale
+    sampler_name = p.sampler_name
+    steps = p.steps
+
     for i, image in enumerate(images):
         state.job = f"{i+1} out of {len(images)}"
         if state.skipped:
@@ -79,25 +90,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
             mask_image = Image.open(mask_image_path)
             p.image_mask = mask_image
 
+        if use_png_info:
+            try:
+                info_img = img
+                if png_info_dir:
+                    info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+                    info_img = Image.open(info_img_path)
+                geninfo, _ = imgutil.read_info_from_image(info_img)
+                parsed_parameters = parse_generation_parameters(geninfo)
+                parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+            except Exception:
+                parsed_parameters = {}
+
+            p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+            p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+            p.seed = int(parsed_parameters.get("Seed", seed))
+            p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+            p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+            p.steps = int(parsed_parameters.get("Steps", steps))
+
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
             proc = process_images(p)
 
         for n, processed_image in enumerate(proc.images):
-            filename = image_path.name
+            filename = image_path.stem
+            infotext = proc.infotext(p, n)
+            relpath = os.path.dirname(os.path.relpath(image, input_dir))
 
             if n > 0:
-                left, right = os.path.splitext(filename)
-                filename = f"{left}-{n}{right}"
+                filename += f"-{n}"
 
             if not save_normally:
-                os.makedirs(output_dir, exist_ok=True)
+                os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
                 if processed_image.mode == 'RGBA':
                     processed_image = processed_image.convert("RGB")
-                processed_image.save(os.path.join(output_dir, filename))
+                save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
 
 
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     is_batch = mode == 5
@@ -180,24 +211,25 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
     p.scripts = modules.scripts.scripts_img2img
     p.script_args = args
 
+    p.user = request.username
+
     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
 
     if mask:
         p.extra_generation_params["Mask blur"] = mask_blur
 
-    if is_batch:
-        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+    with closing(p):
+        if is_batch:
+            assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
 
-        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
-
-        processed = Processed(p, [], p.seed, "")
-    else:
-        processed = modules.scripts.scripts_img2img.run(p, *args)
-        if processed is None:
-            processed = process_images(p)
+            process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
 
-    p.close()
+            processed = Processed(p, [], p.seed, "")
+        else:
+            processed = modules.scripts.scripts_img2img.run(p, *args)
+            if processed is None:
+                processed = process_images(p)
 
     shared.total_tqdm.clear()
 

+ 1 - 2
modules/interrogate.py

@@ -184,8 +184,7 @@ class InterrogateModels:
 
     def interrogate(self, pil_image):
         res = ""
-        shared.state.begin()
-        shared.state.job = 'interrogate'
+        shared.state.begin(job="interrogate")
         try:
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                 lowvram.send_everything_to_cpu()

+ 3 - 3
modules/launch_utils.py

@@ -142,15 +142,15 @@ def git_clone(url, dir, name, commithash=None):
         if commithash is None:
             return
 
-        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
         if current_hash == commithash:
             return
 
         run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
         return
 
-    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
 
     if commithash is not None:
         run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")

+ 31 - 8
modules/mac_specific.py

@@ -1,20 +1,43 @@
+import logging
+
 import torch
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 
+log = logging.getLogger(__name__)
+
 
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
 def check_for_mps() -> bool:
-    if not getattr(torch, 'has_mps', False):
-        return False
+    if version.parse(torch.__version__) <= version.parse("2.0.1"):
+        if not getattr(torch, 'has_mps', False):
+            return False
+        try:
+            torch.zeros(1).to(torch.device("mps"))
+            return True
+        except Exception:
+            return False
+    else:
+        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
+has_mps = check_for_mps()
+
+
+def torch_mps_gc() -> None:
     try:
-        torch.zeros(1).to(torch.device("mps"))
-        return True
+        from modules.shared import state
+        if state.current_latent is not None:
+            log.debug("`current_latent` is set, skipping MPS garbage collection")
+            return
+        from torch.mps import empty_cache
+        empty_cache()
     except Exception:
-        return False
-has_mps = check_for_mps()
+        log.warning("MPS garbage collection failed", exc_info=True)
 
 
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784

+ 27 - 4
modules/modelloader.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import shutil
 import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
 from modules.paths import script_path, models_path
 
 
+def load_file_from_url(
+    url: str,
+    *,
+    model_dir: str,
+    progress: bool = True,
+    file_name: str | None = None,
+) -> str:
+    """Download a file from `url` into `model_dir`, using the file present if possible.
+
+    Returns the path to the downloaded file.
+    """
+    os.makedirs(model_dir, exist_ok=True)
+    if not file_name:
+        parts = urlparse(url)
+        file_name = os.path.basename(parts.path)
+    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        from torch.hub import download_url_to_file
+        download_url_to_file(url, cached_file, progress=progress)
+    return cached_file
+
+
 def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
     """
     A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
 
         if model_url is not None and len(output) == 0:
             if download_name is not None:
-                from basicsr.utils.download_util import load_file_from_url
-                dl = load_file_from_url(model_url, places[0], True, download_name)
-                output.append(dl)
+                output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
             else:
                 output.append(model_url)
 
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
 
 
 def friendly_name(file: str):
-    if "http" in file:
+    if file.startswith("http"):
         file = urlparse(file).path
 
     file = os.path.basename(file)

+ 0 - 14
modules/paths.py

@@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
         else:
             sys.path.append(d)
         paths[what] = d
-
-
-class Prioritize:
-    def __init__(self, name):
-        self.name = name
-        self.path = None
-
-    def __enter__(self):
-        self.path = sys.path.copy()
-        sys.path = [paths[self.name]] + sys.path
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        sys.path = self.path
-        self.path = None

+ 4 - 3
modules/postprocessing.py

@@ -9,8 +9,7 @@ from modules.shared import opts
 def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
     devices.torch_gc()
 
-    shared.state.begin()
-    shared.state.job = 'extras'
+    shared.state.begin(job="extras")
 
     image_data = []
     image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     for image, name in zip(image_data, image_names):
         shared.state.textinfo = name
 
-        existing_pnginfo = image.info or {}
+        parameters, existing_pnginfo = images.read_info_from_image(image)
+        if parameters:
+            existing_pnginfo["parameters"] = parameters
 
         pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
 

+ 15 - 8
modules/processing.py

@@ -184,6 +184,8 @@ class StableDiffusionProcessing:
         self.uc = None
         self.c = None
 
+        self.user = None
+
     @property
     def sd_model(self):
         return shared.sd_model
@@ -549,7 +551,7 @@ def program_version():
     return res
 
 
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
     index = position_in_batch + iteration * p.batch_size
 
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -573,7 +575,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
-        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
+        "User": p.user if opts.add_user_name_to_info else None,
     }
 
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
+    prompt_text = p.prompt if use_main_prompt else all_prompts[index]
     negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
 
-    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -602,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
     try:
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
-        if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
             p.override_settings.pop('sd_model_checkpoint', None)
             sd_models.reload_model_weights()
 
@@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     else:
         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
-    def infotext(iteration=0, position_in_batch=0):
-        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
 
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             grid = images.image_grid(output_images, p.batch_size)
 
             if opts.return_grid:
-                text = infotext()
+                text = infotext(use_main_prompt=True)
                 infotexts.insert(0, text)
                 if opts.enable_pnginfo:
                     grid.info["parameters"] = text
@@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 index_of_first_image = 1
 
             if opts.grid_save:
-                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
     if not p.disable_extra_networks and p.extra_network_data:
         extra_networks.deactivate(p, p.extra_network_data)
@@ -1074,6 +1078,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
 
+        if self.scripts is not None:
+            self.scripts.before_hr(self)
+
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

+ 15 - 18
modules/realesrgan_model.py

@@ -2,7 +2,6 @@ import os
 
 import numpy as np
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 from realesrgan import RealESRGANer
 
 from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
         if not self.enable:
             return img
 
-        info = self.load_model(path)
-        if not os.path.exists(info.local_data_path):
-            print(f"Unable to load RealESRGAN model: {info.name}")
+        try:
+            info = self.load_model(path)
+        except Exception:
+            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
             return img
 
         upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
         return image
 
     def load_model(self, path):
-        try:
-            info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
-            if info is None:
-                print(f"Unable to find model info: {path}")
-                return None
-
-            if info.local_data_path.startswith("http"):
-                info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
-            return info
-        except Exception:
-            errors.report("Error making Real-ESRGAN models list", exc_info=True)
-        return None
+        for scaler in self.scalers:
+            if scaler.data_path == path:
+                if scaler.local_data_path.startswith("http"):
+                    scaler.local_data_path = modelloader.load_file_from_url(
+                        scaler.data_path,
+                        model_dir=self.model_download_path,
+                    )
+                if not os.path.exists(scaler.local_data_path):
+                    raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+                return scaler
+        raise ValueError(f"Unable to find model info: {path}")
 
     def load_models(self, _):
         return get_realesrgan_models(self)

+ 39 - 1
modules/scripts.py

@@ -1,6 +1,7 @@
 import os
 import re
 import sys
+import inspect
 from collections import namedtuple
 
 import gradio as gr
@@ -116,6 +117,21 @@ class Script:
 
         pass
 
+    def after_extra_networks_activate(self, p, *args, **kwargs):
+        """
+        Calledafter extra networks activation, before conds calculation
+        allow modification of the network after extra networks activation been applied
+        won't be call if p.disable_extra_networks
+
+        **kwargs will have those items:
+          - batch_number - index of current batch, from 0 to number of batches-1
+          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+          - seeds - list of seeds for current batch
+          - subseeds - list of subseeds for current batch
+          - extra_network_data - list of ExtraNetworkParams for current stage
+        """
+        pass
+
     def process_batch(self, p, *args, **kwargs):
         """
         Same as process(), but called for every batch.
@@ -186,6 +202,11 @@ class Script:
 
         return f'script_{tabname}{title}_{item_id}'
 
+    def before_hr(self, p, *args):
+        """
+        This function is called before hires fix start.
+        """
+        pass
 
 current_basedir = paths.script_path
 
@@ -249,7 +270,7 @@ def load_scripts():
 
     def register_scripts_from_module(module):
         for script_class in module.__dict__.values():
-            if type(script_class) != type:
+            if not inspect.isclass(script_class):
                 continue
 
             if issubclass(script_class, Script):
@@ -483,6 +504,14 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
 
+    def after_extra_networks_activate(self, p, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.after_extra_networks_activate(p, *script_args, **kwargs)
+            except Exception:
+                errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
+
     def process_batch(self, p, **kwargs):
         for script in self.alwayson_scripts:
             try:
@@ -548,6 +577,15 @@ class ScriptRunner:
                     self.scripts[si].args_to = args_to
 
 
+    def before_hr(self, p):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.before_hr(p, *script_args)
+            except Exception:
+                errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
 scripts_txt2img: ScriptRunner = None
 scripts_img2img: ScriptRunner = None
 scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None

+ 12 - 7
modules/sd_models.py

@@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
 
 checkpoints_list = {}
-checkpoint_alisases = {}
+checkpoint_aliases = {}
+checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()
 
 
@@ -66,7 +67,7 @@ class CheckpointInfo:
     def register(self):
         checkpoints_list[self.title] = self
         for id in self.ids:
-            checkpoint_alisases[id] = self
+            checkpoint_aliases[id] = self
 
     def calculate_shorthash(self):
         self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
@@ -112,7 +113,7 @@ def checkpoint_tiles():
 
 def list_models():
     checkpoints_list.clear()
-    checkpoint_alisases.clear()
+    checkpoint_aliases.clear()
 
     cmd_ckpt = shared.cmd_opts.ckpt
     if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
@@ -136,7 +137,7 @@ def list_models():
 
 
 def get_closet_checkpoint_match(search_string):
-    checkpoint_info = checkpoint_alisases.get(search_string, None)
+    checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
         return checkpoint_info
 
@@ -166,7 +167,7 @@ def select_checkpoint():
     """Raises `FileNotFoundError` if no checkpoints are found."""
     model_checkpoint = shared.opts.sd_model_checkpoint
 
-    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
+    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
     if checkpoint_info is not None:
         return checkpoint_info
 
@@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
         device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
-        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+        if not shared.opts.disable_mmap_load_safetensors:
+            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+        else:
+            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
 
@@ -585,7 +591,6 @@ def unload_model_weights(sd_model=None, info=None):
         sd_model = None
         gc.collect()
         devices.torch_gc()
-        torch.cuda.empty_cache()
 
     print(f"Unloaded weights {timer.summary()}.")
 

+ 29 - 6
modules/shared.py

@@ -1,9 +1,11 @@
 import datetime
 import json
 import os
+import re
 import sys
 import threading
 import time
+import logging
 
 import gradio as gr
 import torch
@@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
 from ldm.models.diffusion.ddpm import LatentDiffusion
 from typing import Optional
 
+log = logging.getLogger(__name__)
+
 demo = None
 
 parser = cmd_args.parser
@@ -144,12 +148,15 @@ class State:
     def request_restart(self) -> None:
         self.interrupt()
         self.server_command = "restart"
+        log.info("Received restart request")
 
     def skip(self):
         self.skipped = True
+        log.info("Received skip request")
 
     def interrupt(self):
         self.interrupted = True
+        log.info("Received interrupt request")
 
     def nextjob(self):
         if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -173,7 +180,7 @@ class State:
 
         return obj
 
-    def begin(self):
+    def begin(self, job: str = "(unknown)"):
         self.sampling_step = 0
         self.job_count = -1
         self.processing_has_refined_job_count = False
@@ -187,10 +194,13 @@ class State:
         self.interrupted = False
         self.textinfo = None
         self.time_start = time.time()
-
+        self.job = job
         devices.torch_gc()
+        log.info("Starting job %s", job)
 
     def end(self):
+        duration = time.time() - self.time_start
+        log.info("Ending job %s (%.2f seconds)", self.job, duration)
         self.job = ""
         self.job_count = 0
 
@@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
     "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
     "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
     "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+    "font": OptionInfo("", "Font for image grids that have text"),
+    "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+    "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+    "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
 
     "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
     "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@@ -376,6 +390,7 @@ options_templates.update(options_section(('system', "System"), {
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+    "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
 }))
 
 options_templates.update(options_section(('training', "Training"), {
@@ -470,7 +485,6 @@ options_templates.update(options_section(('ui', "User interface"), {
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
-    "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
     "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
@@ -481,6 +495,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+    "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
     "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
@@ -493,6 +508,7 @@ options_templates.update(options_section(('ui', "User interface"), {
 options_templates.update(options_section(('infotext', "Infotext"), {
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+    "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
     "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
     "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
     "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
@@ -817,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
 mem_mon.start()
 
 
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
+    return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
 def listfiles(dirname):
-    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
+    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
     return [file for file in filenames if os.path.isfile(file)]
 
 
@@ -843,8 +863,11 @@ def walk_files(path, allowed_extensions=None):
     if allowed_extensions is not None:
         allowed_extensions = set(allowed_extensions)
 
-    for root, _, files in os.walk(path, followlinks=True):
-        for filename in files:
+    items = list(os.walk(path, followlinks=True))
+    items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+    for root, _, files in items:
+        for filename in sorted(files, key=natural_sort_key):
             if allowed_extensions is not None:
                 _, ext = os.path.splitext(filename)
                 if ext not in allowed_extensions:

+ 44 - 4
modules/textual_inversion/logging.py

@@ -2,11 +2,51 @@ import datetime
 import json
 import os
 
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+    "batch_size",
+    "clip_grad_mode",
+    "clip_grad_value",
+    "create_image_every",
+    "data_root",
+    "gradient_step",
+    "initial_step",
+    "latent_sampling_method",
+    "learn_rate",
+    "log_directory",
+    "model_hash",
+    "model_name",
+    "num_of_dataset_images",
+    "steps",
+    "template_file",
+    "training_height",
+    "training_width",
+}
+saved_params_ti = {
+    "embedding_name",
+    "num_vectors_per_token",
+    "save_embedding_every",
+    "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+    "activation_func",
+    "add_layer_norm",
+    "hypernetwork_name",
+    "layer_structure",
+    "save_hypernetwork_every",
+    "use_dropout",
+    "weight_init",
+}
 saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+    "preview_cfg_scale",
+    "preview_height",
+    "preview_negative_prompt",
+    "preview_prompt",
+    "preview_sampler_index",
+    "preview_seed",
+    "preview_steps",
+    "preview_width",
+}
 
 
 def save_settings_to_file(log_directory, all_params):

+ 1 - 1
modules/textual_inversion/preprocess.py

@@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
 from modules.textual_inversion import autocrop
 
 
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
     try:
         if process_caption:
             shared.interrogator.load()

+ 4 - 2
modules/textual_inversion/textual_inversion.py

@@ -1,5 +1,6 @@
 import os
 from collections import namedtuple
+from contextlib import closing
 
 import torch
 import tqdm
@@ -584,8 +585,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 
                     preview_text = p.prompt
 
-                    processed = processing.process_images(p)
-                    image = processed.images[0] if len(processed.images) > 0 else None
+                    with closing(p):
+                        processed = processing.process_images(p)
+                        image = processed.images[0] if len(processed.images) > 0 else None
 
                     if unload:
                         shared.sd_model.first_stage_model.to(devices.cpu)

+ 10 - 7
modules/txt2img.py

@@ -1,13 +1,15 @@
+from contextlib import closing
+
 import modules.scripts
 from modules import sd_samplers, processing
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 from modules.ui import plaintext_to_html
+import gradio as gr
 
 
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     p = processing.StableDiffusionProcessingTxt2Img(
@@ -48,15 +50,16 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
     p.scripts = modules.scripts.scripts_txt2img
     p.script_args = args
 
+    p.user = request.username
+
     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
 
-    processed = modules.scripts.scripts_txt2img.run(p, *args)
-
-    if processed is None:
-        processed = processing.process_images(p)
+    with closing(p):
+        processed = modules.scripts.scripts_txt2img.run(p, *args)
 
-    p.close()
+        if processed is None:
+            processed = processing.process_images(p)
 
     shared.total_tqdm.clear()
 

+ 10 - 3
modules/ui.py

@@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
             img = Image.open(image)
             filename = os.path.basename(image)
             left, _ = os.path.splitext(filename)
-            print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+            print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
 
         return [gr.update(), None]
 
@@ -733,6 +733,10 @@ def create_ui():
                         img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                         img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
                         img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+                        with gr.Accordion("PNG info", open=False):
+                            img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+                            img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+                            img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
 
                     img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
 
@@ -773,7 +777,7 @@ def create_ui():
                                 selected_scale_tab = gr.State(value=0)
 
                                 with gr.Tabs():
-                                    with gr.Tab(label="Resize to") as tab_scale_to:
+                                    with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
                                         with FormRow():
                                             with gr.Column(elem_id="img2img_column_size", scale=4):
                                                 width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -782,7 +786,7 @@ def create_ui():
                                                 res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
                                                 detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
 
-                                    with gr.Tab(label="Resize by") as tab_scale_by:
+                                    with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
                                         scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
 
                                         with FormRow():
@@ -934,6 +938,9 @@ def create_ui():
                     img2img_batch_output_dir,
                     img2img_batch_inpaint_mask_dir,
                     override_settings,
+                    img2img_batch_use_png_info,
+                    img2img_batch_png_info_props,
+                    img2img_batch_png_info_dir,
                 ] + custom_inputs,
                 outputs=[
                     img2img_gallery,

+ 23 - 6
modules/ui_extensions.py

@@ -138,7 +138,10 @@ def extension_table():
     <table id="extensions">
         <thead>
             <tr>
-                <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+                <th>
+                    <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+                    <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+                </th>
                 <th>URL</th>
                 <th>Branch</th>
                 <th>Version</th>
@@ -170,7 +173,7 @@ def extension_table():
 
         code += f"""
             <tr>
-                <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+                <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
                 <td>{remote}</td>
                 <td>{ext.branch}</td>
                 <td>{version_link}</td>
@@ -421,9 +424,19 @@ sort_ordering = [
     (False, lambda x: x.get('name', 'z')),
     (True, lambda x: x.get('name', 'z')),
     (False, lambda x: 'z'),
+    (True, lambda x: x.get('commit_time', '')),
+    (True, lambda x: x.get('created_at', '')),
+    (True, lambda x: x.get('stars', 0)),
 ]
 
 
+def get_date(info: dict, key):
+    try:
+        return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+    except (ValueError, TypeError):
+        return ''
+
+
 def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
     extlist = available_extensions["extensions"]
     installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
 
     for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
         name = ext.get("name", "noname")
+        stars = int(ext.get("stars", 0))
         added = ext.get('added', 'unknown')
+        update_time = get_date(ext, 'commit_time')
+        create_time = get_date(ext, 'created_at')
         url = ext.get("url", None)
         description = ext.get("description", "")
         extension_tags = ext.get("tags", [])
@@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
         code += f"""
             <tr>
                 <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
-                <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
+                <td>{html.escape(description)}<p class="info">
+                <span class="date_added">Update: {html.escape(update_time)}  Added: {html.escape(added)}  Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
                 <td>{install_code}</td>
             </tr>
 
@@ -559,7 +576,7 @@ def create_ui():
 
                 with gr.Row():
                     hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
-                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
 
                 with gr.Row():
                     search_extensions_text = gr.Text(label="Search").style(container=False)
@@ -568,9 +585,9 @@ def create_ui():
                 available_extensions_table = gr.HTML()
 
                 refresh_available_extensions_button.click(
-                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
                     inputs=[available_extensions_index, hide_tags, sort_column],
-                    outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+                    outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
                 )
 
                 install_extension_button.click(

+ 4 - 4
modules/ui_extra_networks.py

@@ -30,8 +30,8 @@ def fetch_file(filename: str = ""):
         raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
 
     ext = os.path.splitext(filename)[1].lower()
-    if ext not in (".png", ".jpg", ".jpeg", ".webp"):
-        raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+    if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
+        raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
 
     # would profit from returning 304
     return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -90,8 +90,8 @@ class ExtraNetworksPage:
 
         subdirs = {}
         for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
-            for root, dirs, _ in os.walk(parentdir, followlinks=True):
-                for dirname in dirs:
+            for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
+                for dirname in sorted(dirs, key=shared.natural_sort_key):
                     x = os.path.join(root, dirname)
 
                     if not os.path.isdir(x):

+ 14 - 7
modules/ui_settings.py

@@ -260,13 +260,20 @@ class UiSettings:
             component = self.component_dict[k]
             info = opts.data_labels[k]
 
-            change_handler = component.release if hasattr(component, 'release') else component.change
-            change_handler(
-                fn=lambda value, k=k: self.run_settings_single(value, key=k),
-                inputs=[component],
-                outputs=[component, self.text_settings],
-                show_progress=info.refresh is not None,
-            )
+            if isinstance(component, gr.Textbox):
+                methods = [component.submit, component.blur]
+            elif hasattr(component, 'release'):
+                methods = [component.release]
+            else:
+                methods = [component.change]
+
+            for method in methods:
+                method(
+                    fn=lambda value, k=k: self.run_settings_single(value, key=k),
+                    inputs=[component],
+                    outputs=[component, self.text_settings],
+                    show_progress=info.refresh is not None,
+                )
 
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(

+ 15 - 2
style.css

@@ -704,11 +704,24 @@ table.popup-table .link{
     margin: 0;
 }
 
-#available_extensions .date_added{
-    opacity: 0.85;
+#available_extensions .info{
+    margin: 0.5em 0;
+    display: flex;
+    margin-top: auto;
+    opacity: 0.80;
     font-size: 90%;
 }
 
+#available_extensions .date_added{
+    margin-right: auto;
+    display: inline-block;
+}
+
+#available_extensions .star_count{
+    margin-left: auto;
+    display: inline-block;
+}
+
 /* replace original footer with ours */
 
 footer {

+ 19 - 15
webui.py

@@ -11,13 +11,24 @@ import json
 from threading import Thread
 from typing import Iterable
 
-from fastapi import FastAPI, Response
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from packaging import version
 
 import logging
 
+# We can't use cmd_opts for this because it will not have been initialized at this point.
+log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
+if log_level:
+    log_level = getattr(logging, log_level.upper(), None) or logging.INFO
+    logging.basicConfig(
+        level=log_level,
+        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S',
+    )
+
+logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
 from modules import paths, timer, import_hook, errors, devices  # noqa: F401
@@ -32,7 +43,7 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
 
 startup_timer.record("import torch")
 
-import gradio
+import gradio  # noqa: F401
 startup_timer.record("import gradio")
 
 import ldm.modules.encoders.modules  # noqa: F401
@@ -359,12 +370,11 @@ def api_only():
     modules.script_callbacks.app_started_callback(None, app)
 
     print(f"Startup time: {startup_timer.summary()}.")
-    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-
-
-def stop_route(request):
-    shared.state.server_command = "stop"
-    return Response("Stopping.")
+    api.launch(
+        server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
+        port=cmd_opts.port if cmd_opts.port else 7861,
+        root_path = f"/{cmd_opts.subpath}"
+    )
 
 
 def webui():
@@ -403,9 +413,8 @@ def webui():
                 "docs_url": "/docs",
                 "redoc_url": "/redoc",
             },
+            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
         )
-        if cmd_opts.add_stop_route:
-            app.add_route("/_stop", stop_route, methods=["POST"])
 
         # after initial launch, disable --autolaunch for subsequent restarts
         cmd_opts.autolaunch = False
@@ -436,11 +445,6 @@ def webui():
         timer.startup_record = startup_timer.dump()
         print(f"Startup time: {startup_timer.summary()}.")
 
-        if cmd_opts.subpath:
-            redirector = FastAPI()
-            redirector.get("/")
-            gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
-
         try:
             while True:
                 server_command = shared.state.wait_for_server_command(timeout=5)

+ 11 - 5
webui.sh

@@ -4,26 +4,28 @@
 # change the variables in webui-user.sh instead #
 #################################################
 
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
 # If run from macOS, load defaults from webui-macos-env.sh
 if [[ "$OSTYPE" == "darwin"* ]]; then
-    if [[ -f webui-macos-env.sh ]]
+    if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
         then
-        source ./webui-macos-env.sh
+        source "$SCRIPT_DIR"/webui-macos-env.sh
     fi
 fi
 
 # Read variables from webui-user.sh
 # shellcheck source=/dev/null
-if [[ -f webui-user.sh ]]
+if [[ -f "$SCRIPT_DIR"/webui-user.sh ]]
 then
-    source ./webui-user.sh
+    source "$SCRIPT_DIR"/webui-user.sh
 fi
 
 # Set defaults
 # Install directory without trailing slash
 if [[ -z "${install_dir}" ]]
 then
-    install_dir="$(pwd)"
+    install_dir="$SCRIPT_DIR"
 fi
 
 # Name of the subdirectory (defaults to stable-diffusion-webui)
@@ -131,6 +133,10 @@ case "$gpu_info" in
     ;;
     *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
     ;;
+    *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
+        export TORCH_COMMAND="pip install --pre torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 --index-url https://download.pytorch.org/whl/nightly/rocm5.5"
+        # Navi 3 needs at least 5.5 which is only on the nightly chain
+    ;;
     *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
         printf "\n%s\n" "${delimiter}"
         printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"