Просмотр исходного кода

Merge branch 'dev' into master

AUTOMATIC1111 2 лет назад
Родитель
Сommit
b7c5b30f14
49 измененных файлов с 705 добавлено и 390 удалено
  1. 1 1
      .github/workflows/on_pull_request.yaml
  2. 2 2
      .github/workflows/run_tests.yaml
  3. 3 0
      README.md
  4. 3 5
      extensions-builtin/LDSR/ldsr_model_arch.py
  5. 8 12
      extensions-builtin/LDSR/scripts/ldsr_model.py
  6. 1 1
      extensions-builtin/Lora/lora.py
  7. 13 17
      extensions-builtin/ScuNET/scripts/scunet_model.py
  8. 49 34
      extensions-builtin/SwinIR/scripts/swinir_model.py
  9. 29 1
      extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
  10. 1 0
      extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
  11. 3 2
      javascript/edit-attention.js
  12. 41 0
      javascript/edit-order.js
  13. 18 0
      javascript/extensions.js
  14. 73 56
      modules/api/api.py
  15. 0 4
      modules/api/models.py
  16. 4 1
      modules/call_queue.py
  17. 1 0
      modules/cmd_args.py
  18. 1 5
      modules/codeformer_model.py
  19. 4 7
      modules/devices.py
  20. 10 13
      modules/esrgan_model.py
  21. 3 0
      modules/extra_networks.py
  22. 1 2
      modules/extras.py
  23. 0 29
      modules/generation_parameters_copypaste.py
  24. 1 1
      modules/gfpgan_model.py
  25. 4 26
      modules/hypernetworks/hypernetwork.py
  26. 48 21
      modules/images.py
  27. 52 20
      modules/img2img.py
  28. 1 2
      modules/interrogate.py
  29. 3 3
      modules/launch_utils.py
  30. 31 8
      modules/mac_specific.py
  31. 27 4
      modules/modelloader.py
  32. 0 14
      modules/paths.py
  33. 4 3
      modules/postprocessing.py
  34. 15 8
      modules/processing.py
  35. 15 18
      modules/realesrgan_model.py
  36. 39 1
      modules/scripts.py
  37. 12 7
      modules/sd_models.py
  38. 29 6
      modules/shared.py
  39. 44 4
      modules/textual_inversion/logging.py
  40. 1 1
      modules/textual_inversion/preprocess.py
  41. 4 2
      modules/textual_inversion/textual_inversion.py
  42. 10 7
      modules/txt2img.py
  43. 10 3
      modules/ui.py
  44. 23 6
      modules/ui_extensions.py
  45. 4 4
      modules/ui_extra_networks.py
  46. 14 7
      modules/ui_settings.py
  47. 15 2
      style.css
  48. 19 15
      webui.py
  49. 11 5
      webui.sh

+ 1 - 1
.github/workflows/on_pull_request.yaml

@@ -18,7 +18,7 @@ jobs:
           #     not to have GHA download an (at the time of writing) 4 GB cache
           #     not to have GHA download an (at the time of writing) 4 GB cache
           #     of PyTorch and other dependencies.
           #     of PyTorch and other dependencies.
       - name: Install Ruff
       - name: Install Ruff
-        run: pip install ruff==0.0.265
+        run: pip install ruff==0.0.272
       - name: Run Ruff
       - name: Run Ruff
         run: ruff .
         run: ruff .
   lint-js:
   lint-js:

+ 2 - 2
.github/workflows/run_tests.yaml

@@ -42,7 +42,7 @@ jobs:
           --no-half
           --no-half
           --disable-opt-split-attention
           --disable-opt-split-attention
           --use-cpu all
           --use-cpu all
-          --add-stop-route
+          --api-server-stop
           2>&1 | tee output.txt &
           2>&1 | tee output.txt &
       - name: Run tests
       - name: Run tests
         run: |
         run: |
@@ -50,7 +50,7 @@ jobs:
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
       - name: Kill test server
         if: always()
         if: always()
-        run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10
+        run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
       - name: Show coverage
       - name: Show coverage
         run: |
         run: |
           python -m coverage combine .coverage*
           python -m coverage combine .coverage*

+ 3 - 0
README.md

@@ -135,8 +135,11 @@ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-w
 Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
 Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
 
 
 ## Documentation
 ## Documentation
+
 The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
 The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
 
 
+For the purposes of getting Google and other search engines to crawl the wiki, here's a link to the (not for humans) [crawlable wiki](https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki).
+
 ## Credits
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 
 

+ 3 - 5
extensions-builtin/LDSR/ldsr_model_arch.py

@@ -12,7 +12,7 @@ import safetensors.torch
 
 
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
 from ldm.util import instantiate_from_config, ismap
-from modules import shared, sd_hijack
+from modules import shared, sd_hijack, devices
 
 
 cached_ldsr_model: torch.nn.Module = None
 cached_ldsr_model: torch.nn.Module = None
 
 
@@ -112,8 +112,7 @@ class LDSR:
 
 
 
 
         gc.collect()
         gc.collect()
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
+        devices.torch_gc()
 
 
         im_og = image
         im_og = image
         width_og, height_og = im_og.size
         width_og, height_og = im_og.size
@@ -150,8 +149,7 @@ class LDSR:
 
 
         del model
         del model
         gc.collect()
         gc.collect()
-        if torch.cuda.is_available:
-            torch.cuda.empty_cache()
+        devices.torch_gc()
 
 
         return a
         return a
 
 

+ 8 - 12
extensions-builtin/LDSR/scripts/ldsr_model.py

@@ -1,7 +1,6 @@
 import os
 import os
 
 
-from basicsr.utils.download_util import load_file_from_url
-
+from modules.modelloader import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks, errors
 from modules import shared, script_callbacks, errors
@@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
         if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
         if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
             model = local_safetensors_path
             model = local_safetensors_path
         else:
         else:
-            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
+            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
 
 
-        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
+        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
 
 
-        try:
-            return LDSR(model, yaml)
-        except Exception:
-            errors.report("Error importing LDSR", exc_info=True)
-        return None
+        return LDSR(model, yaml)
 
 
     def do_upscale(self, img, path):
     def do_upscale(self, img, path):
-        ldsr = self.load_model(path)
-        if ldsr is None:
-            print("NO LDSR!")
+        try:
+            ldsr = self.load_model(path)
+        except Exception:
+            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
             return img
             return img
         ddim_steps = shared.opts.ldsr_steps
         ddim_steps = shared.opts.ldsr_steps
         return ldsr.super_resolution(img, ddim_steps, self.scale)
         return ldsr.super_resolution(img, ddim_steps, self.scale)

+ 1 - 1
extensions-builtin/Lora/lora.py

@@ -443,7 +443,7 @@ def list_available_loras():
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
 
     candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
     candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
-    for filename in sorted(candidates, key=str.lower):
+    for filename in candidates:
         if os.path.isdir(filename):
         if os.path.isdir(filename):
             continue
             continue
 
 

+ 13 - 17
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -1,4 +1,3 @@
-import os.path
 import sys
 import sys
 
 
 import PIL.Image
 import PIL.Image
@@ -6,12 +5,11 @@ import numpy as np
 import torch
 import torch
 from tqdm import tqdm
 from tqdm import tqdm
 
 
-from basicsr.utils.download_util import load_file_from_url
-
 import modules.upscaler
 import modules.upscaler
 from modules import devices, modelloader, script_callbacks, errors
 from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet
 
 
+from modules.modelloader import load_file_from_url
 from modules.shared import opts
 from modules.shared import opts
 
 
 
 
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         scalers = []
         scalers = []
         add_model2 = True
         add_model2 = True
         for file in model_paths:
         for file in model_paths:
-            if "http" in file:
+            if file.startswith("http"):
                 name = self.model_name
                 name = self.model_name
             else:
             else:
                 name = modelloader.friendly_name(file)
                 name = modelloader.friendly_name(file)
@@ -87,11 +85,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
 
     def do_upscale(self, img: PIL.Image.Image, selected_file):
     def do_upscale(self, img: PIL.Image.Image, selected_file):
 
 
-        torch.cuda.empty_cache()
+        devices.torch_gc()
 
 
-        model = self.load_model(selected_file)
-        if model is None:
-            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+        try:
+            model = self.load_model(selected_file)
+        except Exception as e:
+            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img
             return img
 
 
         device = devices.get_device_for('scunet')
         device = devices.get_device_for('scunet')
@@ -111,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
         torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
         np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
         np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
         del torch_img, torch_output
         del torch_img, torch_output
-        torch.cuda.empty_cache()
+        devices.torch_gc()
 
 
         output = np_output.transpose((1, 2, 0))  # CHW to HWC
         output = np_output.transpose((1, 2, 0))  # CHW to HWC
         output = output[:, :, ::-1]  # BGR to RGB
         output = output[:, :, ::-1]  # BGR to RGB
@@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
 
     def load_model(self, path: str):
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
         device = devices.get_device_for('scunet')
-        if "http" in path:
-            filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+        if path.startswith("http"):
+            # TODO: this doesn't use `path` at all?
+            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
         else:
             filename = path
             filename = path
-        if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
-            print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
-            return None
-
-        model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
         model.load_state_dict(torch.load(filename), strict=True)
         model.load_state_dict(torch.load(filename), strict=True)
         model.eval()
         model.eval()
         for _, v in model.named_parameters():
         for _, v in model.named_parameters():

+ 49 - 34
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -1,34 +1,35 @@
-import os
+import sys
+import platform
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
 from PIL import Image
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 from tqdm import tqdm
 
 
 from modules import modelloader, devices, script_callbacks, shared
 from modules import modelloader, devices, script_callbacks, shared
 from modules.shared import opts, state
 from modules.shared import opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
 from modules.upscaler import Upscaler, UpscalerData
 from modules.upscaler import Upscaler, UpscalerData
 
 
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
 
 
 device_swinir = devices.get_device_for('swinir')
 device_swinir = devices.get_device_for('swinir')
 
 
 
 
 class UpscalerSwinIR(Upscaler):
 class UpscalerSwinIR(Upscaler):
     def __init__(self, dirname):
     def __init__(self, dirname):
+        self._cached_model = None           # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
+        self._cached_model_config = None    # to clear '_cached_model' when changing model (v1/v2) or settings
         self.name = "SwinIR"
         self.name = "SwinIR"
-        self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
-                         "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
-                         "-L_x4_GAN.pth "
+        self.model_url = SWINIR_MODEL_URL
         self.model_name = "SwinIR 4x"
         self.model_name = "SwinIR 4x"
         self.user_path = dirname
         self.user_path = dirname
         super().__init__()
         super().__init__()
         scalers = []
         scalers = []
         model_files = self.find_models(ext_filter=[".pt", ".pth"])
         model_files = self.find_models(ext_filter=[".pt", ".pth"])
         for model in model_files:
         for model in model_files:
-            if "http" in model:
+            if model.startswith("http"):
                 name = self.model_name
                 name = self.model_name
             else:
             else:
                 name = modelloader.friendly_name(model)
                 name = modelloader.friendly_name(model)
@@ -37,42 +38,54 @@ class UpscalerSwinIR(Upscaler):
         self.scalers = scalers
         self.scalers = scalers
 
 
     def do_upscale(self, img, model_file):
     def do_upscale(self, img, model_file):
-        model = self.load_model(model_file)
-        if model is None:
-            return img
-        model = model.to(device_swinir, dtype=devices.dtype)
+        use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
+            and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
+        current_config = (model_file, opts.SWIN_tile)
+
+        if use_compile and self._cached_model_config == current_config:
+            model = self._cached_model
+        else:
+            self._cached_model = None
+            try:
+                model = self.load_model(model_file)
+            except Exception as e:
+                print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
+                return img
+            model = model.to(device_swinir, dtype=devices.dtype)
+            if use_compile:
+                model = torch.compile(model)
+                self._cached_model = model
+                self._cached_model_config = current_config
         img = upscale(img, model)
         img = upscale(img, model)
-        try:
-            torch.cuda.empty_cache()
-        except Exception:
-            pass
+        devices.torch_gc()
         return img
         return img
 
 
     def load_model(self, path, scale=4):
     def load_model(self, path, scale=4):
-        if "http" in path:
-            dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
-            filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+        if path.startswith("http"):
+            filename = modelloader.load_file_from_url(
+                url=path,
+                model_dir=self.model_download_path,
+                file_name=f"{self.model_name.replace(' ', '_')}.pth",
+            )
         else:
         else:
             filename = path
             filename = path
-        if filename is None or not os.path.exists(filename):
-            return None
         if filename.endswith(".v2.pth"):
         if filename.endswith(".v2.pth"):
-            model = net2(
-            upscale=scale,
-            in_chans=3,
-            img_size=64,
-            window_size=8,
-            img_range=1.0,
-            depths=[6, 6, 6, 6, 6, 6],
-            embed_dim=180,
-            num_heads=[6, 6, 6, 6, 6, 6],
-            mlp_ratio=2,
-            upsampler="nearest+conv",
-            resi_connection="1conv",
+            model = Swin2SR(
+                upscale=scale,
+                in_chans=3,
+                img_size=64,
+                window_size=8,
+                img_range=1.0,
+                depths=[6, 6, 6, 6, 6, 6],
+                embed_dim=180,
+                num_heads=[6, 6, 6, 6, 6, 6],
+                mlp_ratio=2,
+                upsampler="nearest+conv",
+                resi_connection="1conv",
             )
             )
             params = None
             params = None
         else:
         else:
-            model = net(
+            model = SwinIR(
                 upscale=scale,
                 upscale=scale,
                 in_chans=3,
                 in_chans=3,
                 img_size=64,
                 img_size=64,
@@ -172,6 +185,8 @@ def on_ui_settings():
 
 
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows":    # torch.compile() require pytorch 2.0 or above, and not on Windows
+        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
 
 
 
 
 script_callbacks.on_ui_settings(on_ui_settings)
 script_callbacks.on_ui_settings(on_ui_settings)

+ 29 - 1
extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js

@@ -200,7 +200,8 @@ onUiLoaded(async() => {
         canvas_hotkey_move: "KeyF",
         canvas_hotkey_move: "KeyF",
         canvas_hotkey_overlap: "KeyO",
         canvas_hotkey_overlap: "KeyO",
         canvas_disabled_functions: [],
         canvas_disabled_functions: [],
-        canvas_show_tooltip: true
+        canvas_show_tooltip: true,
+        canvas_blur_prompt: false
     };
     };
 
 
     const functionMap = {
     const functionMap = {
@@ -608,6 +609,19 @@ onUiLoaded(async() => {
 
 
         // Handle keydown events
         // Handle keydown events
         function handleKeyDown(event) {
         function handleKeyDown(event) {
+            // Disable key locks to make pasting from the buffer work correctly
+            if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
+                return;
+            }
+
+            // before activating shortcut, ensure user is not actively typing in an input field
+            if (!hotkeysConfig.canvas_blur_prompt) {
+                if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
+                    return;
+                }
+            }
+
+
             const hotkeyActions = {
             const hotkeyActions = {
                 [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
                 [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
                 [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
                 [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
@@ -686,6 +700,20 @@ onUiLoaded(async() => {
 
 
         // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
         // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
         function handleMoveKeyDown(e) {
         function handleMoveKeyDown(e) {
+
+            // Disable key locks to make pasting from the buffer work correctly
+            if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
+                return;
+            }
+
+            // before activating shortcut, ensure user is not actively typing in an input field
+            if (!hotkeysConfig.canvas_blur_prompt) {
+                if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
+                    return;
+                }
+            }
+
+
             if (e.code === hotkeysConfig.canvas_hotkey_move) {
             if (e.code === hotkeysConfig.canvas_hotkey_move) {
                 if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
                 if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
                     e.preventDefault();
                     e.preventDefault();

+ 1 - 0
extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py

@@ -9,5 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
     "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
     "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
     "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
     "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
     "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
     "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
+    "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
     "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
     "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
 }))
 }))

+ 3 - 2
javascript/edit-attention.js

@@ -100,11 +100,12 @@ function keyupEditAttention(event) {
     if (String(weight).length == 1) weight += ".0";
     if (String(weight).length == 1) weight += ".0";
 
 
     if (closeCharacter == ')' && weight == 1) {
     if (closeCharacter == ')' && weight == 1) {
-        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+        var endParenPos = text.substring(selectionEnd).indexOf(')');
+        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
         selectionStart--;
         selectionStart--;
         selectionEnd--;
         selectionEnd--;
     } else {
     } else {
-        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
     }
     }
 
 
     target.focus();
     target.focus();

+ 41 - 0
javascript/edit-order.js

@@ -0,0 +1,41 @@
+/* alt+left/right moves text in prompt */
+
+function keyupEditOrder(event) {
+    if (!opts.keyedit_move) return;
+
+    let target = event.originalTarget || event.composedPath()[0];
+    if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
+    if (!event.altKey) return;
+
+    let isLeft = event.key == "ArrowLeft";
+    let isRight = event.key == "ArrowRight";
+    if (!isLeft && !isRight) return;
+    event.preventDefault();
+
+    let selectionStart = target.selectionStart;
+    let selectionEnd = target.selectionEnd;
+    let text = target.value;
+    let items = text.split(",");
+    let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
+    let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
+    let range = indexEnd - indexStart + 1;
+
+    if (isLeft && indexStart > 0) {
+        items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
+        target.value = items.join();
+        target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
+        target.selectionEnd = items.slice(0, indexEnd).join().length;
+    } else if (isRight && indexEnd < items.length - 1) {
+        items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
+        target.value = items.join();
+        target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
+        target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
+    }
+
+    event.preventDefault();
+    updateInput(target);
+}
+
+addEventListener('keydown', (event) => {
+    keyupEditOrder(event);
+});

+ 18 - 0
javascript/extensions.js

@@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
     }
     }
     return [confirmed, config_state_name, config_restore_type];
     return [confirmed, config_state_name, config_restore_type];
 }
 }
+
+function toggle_all_extensions(event) {
+    gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
+        checkbox_el.checked = event.target.checked;
+    });
+}
+
+function toggle_extension() {
+    let all_extensions_toggled = true;
+    for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
+        if (!checkbox_el.checked) {
+            all_extensions_toggled = false;
+            break;
+        }
+    }
+
+    gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
+}

+ 73 - 56
modules/api/api.py

@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 from secrets import compare_digest
 
 
 import modules.shared as shared
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
 from modules.api import models
 from modules.api import models
 from modules.shared import opts
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
 from modules.sd_vae import vae_dict
 from modules.sd_vae import vae_dict
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules.realesrgan_model import get_realesrgan_models
@@ -30,13 +30,7 @@ from modules import devices
 from typing import Dict, List, Any
 from typing import Dict, List, Any
 import piexif
 import piexif
 import piexif.helper
 import piexif.helper
-
-
-def upscaler_to_index(name: str):
-    try:
-        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+from contextlib import closing
 
 
 
 
 def script_name_to_index(name, scripts):
 def script_name_to_index(name, scripts):
@@ -84,6 +78,8 @@ def encode_pil_to_base64(image):
             image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
             image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
 
 
         elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
         elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+            if image.mode == "RGBA":
+                image = image.convert("RGB")
             parameters = image.info.get('parameters', None)
             parameters = image.info.get('parameters', None)
             exif_bytes = piexif.dump({
             exif_bytes = piexif.dump({
                 "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
                 "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -209,6 +205,11 @@ class Api:
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
 
 
+        if shared.cmd_opts.api_server_stop:
+            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+
         self.default_script_arg_txt2img = []
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
         self.default_script_arg_img2img = []
 
 
@@ -324,19 +325,19 @@ class Api:
         args.pop('save_images', None)
         args.pop('save_images', None)
 
 
         with self.queue_lock:
         with self.queue_lock:
-            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
-            p.scripts = script_runner
-            p.outpath_grids = opts.outdir_txt2img_grids
-            p.outpath_samples = opts.outdir_txt2img_samples
-
-            shared.state.begin()
-            if selectable_scripts is not None:
-                p.script_args = script_args
-                processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
-            else:
-                p.script_args = tuple(script_args) # Need to pass args as tuple here
-                processed = process_images(p)
-            shared.state.end()
+            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+                p.scripts = script_runner
+                p.outpath_grids = opts.outdir_txt2img_grids
+                p.outpath_samples = opts.outdir_txt2img_samples
+
+                shared.state.begin(job="scripts_txt2img")
+                if selectable_scripts is not None:
+                    p.script_args = script_args
+                    processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+                else:
+                    p.script_args = tuple(script_args) # Need to pass args as tuple here
+                    processed = process_images(p)
+                shared.state.end()
 
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
 
@@ -380,20 +381,20 @@ class Api:
         args.pop('save_images', None)
         args.pop('save_images', None)
 
 
         with self.queue_lock:
         with self.queue_lock:
-            p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
-            p.init_images = [decode_base64_to_image(x) for x in init_images]
-            p.scripts = script_runner
-            p.outpath_grids = opts.outdir_img2img_grids
-            p.outpath_samples = opts.outdir_img2img_samples
-
-            shared.state.begin()
-            if selectable_scripts is not None:
-                p.script_args = script_args
-                processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
-            else:
-                p.script_args = tuple(script_args) # Need to pass args as tuple here
-                processed = process_images(p)
-            shared.state.end()
+            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+                p.init_images = [decode_base64_to_image(x) for x in init_images]
+                p.scripts = script_runner
+                p.outpath_grids = opts.outdir_img2img_grids
+                p.outpath_samples = opts.outdir_img2img_samples
+
+                shared.state.begin(job="scripts_img2img")
+                if selectable_scripts is not None:
+                    p.script_args = script_args
+                    processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+                else:
+                    p.script_args = tuple(script_args) # Need to pass args as tuple here
+                    processed = process_images(p)
+                shared.state.end()
 
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
 
@@ -517,6 +518,10 @@ class Api:
         return options
         return options
 
 
     def set_config(self, req: Dict[str, Any]):
     def set_config(self, req: Dict[str, Any]):
+        checkpoint_name = req.get("sd_model_checkpoint", None)
+        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+            raise RuntimeError(f"model {checkpoint_name!r} not found")
+
         for k, v in req.items():
         for k, v in req.items():
             shared.opts.set(k, v)
             shared.opts.set(k, v)
 
 
@@ -598,44 +603,42 @@ class Api:
 
 
     def create_embedding(self, args: dict):
     def create_embedding(self, args: dict):
         try:
         try:
-            shared.state.begin()
+            shared.state.begin(job="create_embedding")
             filename = create_embedding(**args) # create empty embedding
             filename = create_embedding(**args) # create empty embedding
             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
-            shared.state.end()
             return models.CreateResponse(info=f"create embedding filename: {filename}")
             return models.CreateResponse(info=f"create embedding filename: {filename}")
         except AssertionError as e:
         except AssertionError as e:
-            shared.state.end()
             return models.TrainResponse(info=f"create embedding error: {e}")
             return models.TrainResponse(info=f"create embedding error: {e}")
+        finally:
+            shared.state.end()
+
 
 
     def create_hypernetwork(self, args: dict):
     def create_hypernetwork(self, args: dict):
         try:
         try:
-            shared.state.begin()
+            shared.state.begin(job="create_hypernetwork")
             filename = create_hypernetwork(**args) # create empty embedding
             filename = create_hypernetwork(**args) # create empty embedding
-            shared.state.end()
             return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
             return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
         except AssertionError as e:
         except AssertionError as e:
-            shared.state.end()
             return models.TrainResponse(info=f"create hypernetwork error: {e}")
             return models.TrainResponse(info=f"create hypernetwork error: {e}")
+        finally:
+            shared.state.end()
 
 
     def preprocess(self, args: dict):
     def preprocess(self, args: dict):
         try:
         try:
-            shared.state.begin()
+            shared.state.begin(job="preprocess")
             preprocess(**args) # quick operation unless blip/booru interrogation is enabled
             preprocess(**args) # quick operation unless blip/booru interrogation is enabled
             shared.state.end()
             shared.state.end()
-            return models.PreprocessResponse(info = 'preprocess complete')
+            return models.PreprocessResponse(info='preprocess complete')
         except KeyError as e:
         except KeyError as e:
-            shared.state.end()
             return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
             return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
-        except AssertionError as e:
-            shared.state.end()
+        except Exception as e:
             return models.PreprocessResponse(info=f"preprocess error: {e}")
             return models.PreprocessResponse(info=f"preprocess error: {e}")
-        except FileNotFoundError as e:
+        finally:
             shared.state.end()
             shared.state.end()
-            return models.PreprocessResponse(info=f'preprocess error: {e}')
 
 
     def train_embedding(self, args: dict):
     def train_embedding(self, args: dict):
         try:
         try:
-            shared.state.begin()
+            shared.state.begin(job="train_embedding")
             apply_optimizations = shared.opts.training_xattention_optimizations
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             error = None
             filename = ''
             filename = ''
@@ -648,15 +651,15 @@ class Api:
             finally:
             finally:
                 if not apply_optimizations:
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
                     sd_hijack.apply_optimizations()
-                shared.state.end()
             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-        except AssertionError as msg:
-            shared.state.end()
+        except Exception as msg:
             return models.TrainResponse(info=f"train embedding error: {msg}")
             return models.TrainResponse(info=f"train embedding error: {msg}")
+        finally:
+            shared.state.end()
 
 
     def train_hypernetwork(self, args: dict):
     def train_hypernetwork(self, args: dict):
         try:
         try:
-            shared.state.begin()
+            shared.state.begin(job="train_hypernetwork")
             shared.loaded_hypernetworks = []
             shared.loaded_hypernetworks = []
             apply_optimizations = shared.opts.training_xattention_optimizations
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             error = None
@@ -674,9 +677,10 @@ class Api:
                     sd_hijack.apply_optimizations()
                     sd_hijack.apply_optimizations()
                 shared.state.end()
                 shared.state.end()
             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
             return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-        except AssertionError:
+        except Exception as exc:
+            return models.TrainResponse(info=f"train embedding error: {exc}")
+        finally:
             shared.state.end()
             shared.state.end()
-            return models.TrainResponse(info=f"train embedding error: {error}")
 
 
     def get_memory(self):
     def get_memory(self):
         try:
         try:
@@ -716,3 +720,16 @@ class Api:
     def launch(self, server_name, port):
     def launch(self, server_name, port):
         self.app.include_router(self.router)
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+
+    def kill_webui(self):
+        restart.stop_program()
+
+    def restart_webui(self):
+        if restart.is_restartable():
+            restart.restart_program()
+        return Response(status_code=501)
+
+    def stop_webui(request):
+        shared.state.server_command = "stop"
+        return Response("Stopping.")
+

+ 0 - 4
modules/api/models.py

@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
     prompt: Optional[str] = Field(title="Prompt")
     prompt: Optional[str] = Field(title="Prompt")
     negative_prompt: Optional[str] = Field(title="Negative Prompt")
     negative_prompt: Optional[str] = Field(title="Negative Prompt")
 
 
-class ArtistItem(BaseModel):
-    name: str = Field(title="Name")
-    score: float = Field(title="Score")
-    category: str = Field(title="Category")
 
 
 class EmbeddingItem(BaseModel):
 class EmbeddingItem(BaseModel):
     step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
     step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")

+ 4 - 1
modules/call_queue.py

@@ -1,3 +1,4 @@
+from functools import wraps
 import html
 import html
 import threading
 import threading
 import time
 import time
@@ -18,6 +19,7 @@ def wrap_queued_call(func):
 
 
 
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
 def wrap_gradio_gpu_call(func, extra_outputs=None):
+    @wraps(func)
     def f(*args, **kwargs):
     def f(*args, **kwargs):
 
 
         # if the first argument is a string that says "task(...)", it is treated as a job id
         # if the first argument is a string that says "task(...)", it is treated as a job id
@@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
             id_task = None
             id_task = None
 
 
         with queue_lock:
         with queue_lock:
-            shared.state.begin()
+            shared.state.begin(job=id_task)
             progress.start_task(id_task)
             progress.start_task(id_task)
 
 
             try:
             try:
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
 
 
 
 def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
 def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+    @wraps(func)
     def f(*args, extra_outputs_array=extra_outputs, **kwargs):
     def f(*args, extra_outputs_array=extra_outputs, **kwargs):
         run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
         run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
         if run_memmon:
         if run_memmon:

+ 1 - 0
modules/cmd_args.py

@@ -107,4 +107,5 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')

+ 1 - 5
modules/codeformer_model.py

@@ -15,7 +15,6 @@ model_dir = "Codeformer"
 model_path = os.path.join(models_path, model_dir)
 model_path = os.path.join(models_path, model_dir)
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 
 
-have_codeformer = False
 codeformer = None
 codeformer = None
 
 
 
 
@@ -100,7 +99,7 @@ def setup_model(dirname):
                             output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                             output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
                             restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                             restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
                         del output
                         del output
-                        torch.cuda.empty_cache()
+                        devices.torch_gc()
                     except Exception:
                     except Exception:
                         errors.report('Failed inference for CodeFormer', exc_info=True)
                         errors.report('Failed inference for CodeFormer', exc_info=True)
                         restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
                         restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -123,9 +122,6 @@ def setup_model(dirname):
 
 
                 return restored_img
                 return restored_img
 
 
-        global have_codeformer
-        have_codeformer = True
-
         global codeformer
         global codeformer
         codeformer = FaceRestorerCodeFormer(dirname)
         codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)
         shared.face_restorers.append(codeformer)

+ 4 - 7
modules/devices.py

@@ -15,13 +15,6 @@ def has_mps() -> bool:
     else:
     else:
         return mac_specific.has_mps
         return mac_specific.has_mps
 
 
-def extract_device_id(args, name):
-    for x in range(len(args)):
-        if name in args[x]:
-            return args[x + 1]
-
-    return None
-
 
 
 def get_cuda_device_string():
 def get_cuda_device_string():
     from modules import shared
     from modules import shared
@@ -56,11 +49,15 @@ def get_device_for(task):
 
 
 
 
 def torch_gc():
 def torch_gc():
+
     if torch.cuda.is_available():
     if torch.cuda.is_available():
         with torch.cuda.device(get_cuda_device_string()):
         with torch.cuda.device(get_cuda_device_string()):
             torch.cuda.empty_cache()
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
             torch.cuda.ipc_collect()
 
 
+    if has_mps():
+        mac_specific.torch_mps_gc()
+
 
 
 def enable_tf32():
 def enable_tf32():
     if torch.cuda.is_available():
     if torch.cuda.is_available():

+ 10 - 13
modules/esrgan_model.py

@@ -1,15 +1,13 @@
-import os
+import sys
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
 from PIL import Image
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 
 
 import modules.esrgan_model_arch as arch
 import modules.esrgan_model_arch as arch
 from modules import modelloader, images, devices
 from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts
 from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
 
 
 
 
 def mod2normal(state_dict):
 def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
             scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
             scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
             scalers.append(scaler_data)
             scalers.append(scaler_data)
         for file in model_paths:
         for file in model_paths:
-            if "http" in file:
+            if file.startswith("http"):
                 name = self.model_name
                 name = self.model_name
             else:
             else:
                 name = modelloader.friendly_name(file)
                 name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
             self.scalers.append(scaler_data)
             self.scalers.append(scaler_data)
 
 
     def do_upscale(self, img, selected_model):
     def do_upscale(self, img, selected_model):
-        model = self.load_model(selected_model)
-        if model is None:
+        try:
+            model = self.load_model(selected_model)
+        except Exception as e:
+            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
             return img
             return img
         model.to(devices.device_esrgan)
         model.to(devices.device_esrgan)
         img = esrgan_upscale(model, img)
         img = esrgan_upscale(model, img)
         return img
         return img
 
 
     def load_model(self, path: str):
     def load_model(self, path: str):
-        if "http" in path:
-            filename = load_file_from_url(
+        if path.startswith("http"):
+            # TODO: this doesn't use `path` at all?
+            filename = modelloader.load_file_from_url(
                 url=self.model_url,
                 url=self.model_url,
                 model_dir=self.model_download_path,
                 model_dir=self.model_download_path,
                 file_name=f"{self.model_name}.pth",
                 file_name=f"{self.model_name}.pth",
-                progress=True,
             )
             )
         else:
         else:
             filename = path
             filename = path
-        if not os.path.exists(filename) or filename is None:
-            print(f"Unable to load {self.model_path} from {filename}")
-            return None
 
 
         state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
         state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
 
 

+ 3 - 0
modules/extra_networks.py

@@ -103,6 +103,9 @@ def activate(p, extra_network_data):
         except Exception as e:
         except Exception as e:
             errors.display(e, f"activating extra network {extra_network_name}")
             errors.display(e, f"activating extra network {extra_network_name}")
 
 
+    if p.scripts is not None:
+        p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
+
 
 
 def deactivate(p, extra_network_data):
 def deactivate(p, extra_network_data):
     """call deactivate for extra networks in extra_network_data in specified order, then call
     """call deactivate for extra networks in extra_network_data in specified order, then call

+ 1 - 2
modules/extras.py

@@ -73,8 +73,7 @@ def to_half(tensor, enable):
 
 
 
 
 def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
 def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
-    shared.state.begin()
-    shared.state.job = 'model-merge'
+    shared.state.begin(job="model-merge")
 
 
     def fail(message):
     def fail(message):
         shared.state.textinfo = message
         shared.state.textinfo = message

+ 0 - 29
modules/generation_parameters_copypaste.py

@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
     return img, w, h
     return img, w, h
 
 
 
 
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
-    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
-    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
-    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
-    If the infotext has no hash, then a hypernet with the same name will be selected instead.
-    """
-    hypernet_name = hypernet_name.lower()
-    if hypernet_hash is not None:
-        # Try to match the hash in the name
-        for hypernet_key in shared.hypernetworks.keys():
-            result = re_hypernet_hash.search(hypernet_key)
-            if result is not None and result[1] == hypernet_hash:
-                return hypernet_key
-    else:
-        # Fall back to a hypernet with the same name
-        for hypernet_key in shared.hypernetworks.keys():
-            if hypernet_key.lower().startswith(hypernet_name):
-                return hypernet_key
-
-    return None
-
-
 def restore_old_hires_fix_params(res):
 def restore_old_hires_fix_params(res):
     """for infotexts that specify old First pass size parameter, convert it into
     """for infotexts that specify old First pass size parameter, convert it into
     width, height, and hr scale"""
     width, height, and hr scale"""
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     return res
     return res
 
 
 
 
-settings_map = {}
-
-
-
 infotext_to_setting_name_mapping = [
 infotext_to_setting_name_mapping = [
     ('Clip skip', 'CLIP_stop_at_last_layers', ),
     ('Clip skip', 'CLIP_stop_at_last_layers', ),
     ('Conditional mask weight', 'inpainting_mask_weight'),
     ('Conditional mask weight', 'inpainting_mask_weight'),

+ 1 - 1
modules/gfpgan_model.py

@@ -25,7 +25,7 @@ def gfpgann():
         return None
         return None
 
 
     models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
     models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
-    if len(models) == 1 and "http" in models[0]:
+    if len(models) == 1 and models[0].startswith("http"):
         model_file = models[0]
         model_file = models[0]
     elif len(models) != 0:
     elif len(models) != 0:
         latest_file = max(models, key=os.path.getctime)
         latest_file = max(models, key=os.path.getctime)

+ 4 - 26
modules/hypernetworks/hypernetwork.py

@@ -3,6 +3,7 @@ import glob
 import html
 import html
 import os
 import os
 import inspect
 import inspect
+from contextlib import closing
 
 
 import modules.textual_inversion.dataset
 import modules.textual_inversion.dataset
 import torch
 import torch
@@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None):
         shared.loaded_hypernetworks.append(hypernetwork)
         shared.loaded_hypernetworks.append(hypernetwork)
 
 
 
 
-def find_closest_hypernetwork_name(search: str):
-    if not search:
-        return None
-    search = search.lower()
-    applicable = [name for name in shared.hypernetworks if search in name.lower()]
-    if not applicable:
-        return None
-    applicable = sorted(applicable, key=lambda name: len(name))
-    return applicable[0]
-
-
 def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
 def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
     hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
     hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
 
 
@@ -446,18 +436,6 @@ def statistics(data):
     return total_information, recent_information
     return total_information, recent_information
 
 
 
 
-def report_statistics(loss_info:dict):
-    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
-    for key in keys:
-        try:
-            print("Loss statistics for file " + key)
-            info, recent = statistics(list(loss_info[key]))
-            print(info)
-            print(recent)
-        except Exception as e:
-            print(e)
-
-
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     # Remove illegal characters from name.
     # Remove illegal characters from name.
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -734,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 
 
                     preview_text = p.prompt
                     preview_text = p.prompt
 
 
-                    processed = processing.process_images(p)
-                    image = processed.images[0] if len(processed.images) > 0 else None
+                    with closing(p):
+                        processed = processing.process_images(p)
+                        image = processed.images[0] if len(processed.images) > 0 else None
 
 
                     if unload:
                     if unload:
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                         shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
         pbar.leave = False
         pbar.leave = False
         pbar.close()
         pbar.close()
         hypernetwork.eval()
         hypernetwork.eval()
-        #report_statistics(loss_dict)
         sd_hijack_checkpoint.remove()
         sd_hijack_checkpoint.remove()
 
 
 
 

+ 48 - 21
modules/images.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import datetime
 import datetime
 
 
 import pytz
 import pytz
@@ -10,7 +12,7 @@ import re
 import numpy as np
 import numpy as np
 import piexif
 import piexif
 import piexif.helper
 import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
 import string
 import string
 import json
 import json
 import hashlib
 import hashlib
@@ -139,6 +141,11 @@ class GridAnnotation:
 
 
 
 
 def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
 def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+    color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+    color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+    color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
     def wrap(drawing, text, font, line_length):
     def wrap(drawing, text, font, line_length):
         lines = ['']
         lines = ['']
         for word in text.split():
         for word in text.split():
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
 
 
     fnt = get_font(fontsize)
     fnt = get_font(fontsize)
 
 
-    color_active = (0, 0, 0)
-    color_inactive = (153, 153, 153)
-
     pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
     pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
 
 
     cols = im.width // width
     cols = im.width // width
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
     assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
     assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
     assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
     assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
 
 
-    calc_img = Image.new("RGB", (1, 1), "white")
+    calc_img = Image.new("RGB", (1, 1), color_background)
     calc_d = ImageDraw.Draw(calc_img)
     calc_d = ImageDraw.Draw(calc_img)
 
 
     for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
     for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
 
 
     pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
     pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
 
 
-    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
 
 
     for row in range(rows):
     for row in range(rows):
         for col in range(cols):
         for col in range(cols):
@@ -302,12 +306,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
 
 
         if ratio < src_ratio:
         if ratio < src_ratio:
             fill_height = height // 2 - src_h // 2
             fill_height = height // 2 - src_h // 2
-            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
-            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
         elif ratio > src_ratio:
         elif ratio > src_ratio:
             fill_width = width // 2 - src_w // 2
             fill_width = width // 2 - src_w // 2
-            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
-            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
 
 
     return res
     return res
 
 
@@ -372,8 +378,8 @@ class FilenameGenerator:
         'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]
         'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]
         'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
         'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
         'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
         'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+        'user': lambda self: self.p.user,
         'vae_filename': lambda self: self.get_vae_filename(),
         'vae_filename': lambda self: self.get_vae_filename(),
-
     }
     }
     default_time_format = '%Y%m%d%H%M%S'
     default_time_format = '%Y%m%d%H%M%S'
 
 
@@ -497,13 +503,23 @@ def get_next_sequence_number(path, basename):
     return result + 1
     return result + 1
 
 
 
 
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+    """
+    Saves image to filename, including geninfo as text information for generation info.
+    For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+    For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+    """
+
     if extension is None:
     if extension is None:
         extension = os.path.splitext(filename)[1]
         extension = os.path.splitext(filename)[1]
 
 
     image_format = Image.registered_extensions()[extension]
     image_format = Image.registered_extensions()[extension]
 
 
     if extension.lower() == '.png':
     if extension.lower() == '.png':
+        existing_pnginfo = existing_pnginfo or {}
+        if opts.enable_pnginfo:
+            existing_pnginfo[pnginfo_section_name] = geninfo
+
         if opts.enable_pnginfo:
         if opts.enable_pnginfo:
             pnginfo_data = PngImagePlugin.PngInfo()
             pnginfo_data = PngImagePlugin.PngInfo()
             for k, v in (existing_pnginfo or {}).items():
             for k, v in (existing_pnginfo or {}).items():
@@ -622,7 +638,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         """
         """
         temp_file_path = f"{filename_without_extension}.tmp"
         temp_file_path = f"{filename_without_extension}.tmp"
 
 
-        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
 
 
         os.replace(temp_file_path, filename_without_extension + extension)
         os.replace(temp_file_path, filename_without_extension + extension)
 
 
@@ -639,12 +655,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
     oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
     if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
     if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
         ratio = image.width / image.height
         ratio = image.width / image.height
-
+        resize_to = None
         if oversize and ratio > 1:
         if oversize and ratio > 1:
-            image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+            resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
         elif oversize:
         elif oversize:
-            image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+            resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
 
 
+        if resize_to is not None:
+            try:
+                # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+                image = image.resize(resize_to, LANCZOS)
+            except Exception:
+                image = image.resize(resize_to)
         try:
         try:
             _atomically_save_image(image, fullfn_without_extension, ".jpg")
             _atomically_save_image(image, fullfn_without_extension, ".jpg")
         except Exception as e:
         except Exception as e:
@@ -662,8 +684,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     return fullfn, txt_fullfn
     return fullfn, txt_fullfn
 
 
 
 
-def read_info_from_image(image):
-    items = image.info or {}
+IGNORED_INFO_KEYS = {
+    'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+    'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+    'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+    items = (image.info or {}).copy()
 
 
     geninfo = items.pop('parameters', None)
     geninfo = items.pop('parameters', None)
 
 
@@ -679,9 +708,7 @@ def read_info_from_image(image):
             items['exif comment'] = exif_comment
             items['exif comment'] = exif_comment
             geninfo = exif_comment
             geninfo = exif_comment
 
 
-    for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
-                    'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
-                    'icc_profile', 'chromaticity']:
+    for field in IGNORED_INFO_KEYS:
         items.pop(field, None)
         items.pop(field, None)
 
 
     if items.get("Software", None) == "NovelAI":
     if items.get("Software", None) == "NovelAI":

+ 52 - 20
modules/img2img.py

@@ -1,23 +1,26 @@
 import os
 import os
+from contextlib import closing
 from pathlib import Path
 from pathlib import Path
 
 
 import numpy as np
 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
 
 
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 from modules.shared import opts, state
+from modules.images import save_image
 import modules.shared as shared
 import modules.shared as shared
 import modules.processing as processing
 import modules.processing as processing
 from modules.ui import plaintext_to_html
 from modules.ui import plaintext_to_html
 import modules.scripts
 import modules.scripts
 
 
 
 
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
     processing.fix_seed(p)
     processing.fix_seed(p)
 
 
-    images = shared.listfiles(input_dir)
+    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
 
 
     is_inpaint_batch = False
     is_inpaint_batch = False
     if inpaint_mask_dir:
     if inpaint_mask_dir:
@@ -36,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
 
 
     state.job_count = len(images) * p.n_iter
     state.job_count = len(images) * p.n_iter
 
 
+    # extract "default" params to use in case getting png info fails
+    prompt = p.prompt
+    negative_prompt = p.negative_prompt
+    seed = p.seed
+    cfg_scale = p.cfg_scale
+    sampler_name = p.sampler_name
+    steps = p.steps
+
     for i, image in enumerate(images):
     for i, image in enumerate(images):
         state.job = f"{i+1} out of {len(images)}"
         state.job = f"{i+1} out of {len(images)}"
         if state.skipped:
         if state.skipped:
@@ -79,25 +90,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
             mask_image = Image.open(mask_image_path)
             mask_image = Image.open(mask_image_path)
             p.image_mask = mask_image
             p.image_mask = mask_image
 
 
+        if use_png_info:
+            try:
+                info_img = img
+                if png_info_dir:
+                    info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+                    info_img = Image.open(info_img_path)
+                geninfo, _ = imgutil.read_info_from_image(info_img)
+                parsed_parameters = parse_generation_parameters(geninfo)
+                parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+            except Exception:
+                parsed_parameters = {}
+
+            p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+            p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+            p.seed = int(parsed_parameters.get("Seed", seed))
+            p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+            p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+            p.steps = int(parsed_parameters.get("Steps", steps))
+
         proc = modules.scripts.scripts_img2img.run(p, *args)
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
         if proc is None:
             proc = process_images(p)
             proc = process_images(p)
 
 
         for n, processed_image in enumerate(proc.images):
         for n, processed_image in enumerate(proc.images):
-            filename = image_path.name
+            filename = image_path.stem
+            infotext = proc.infotext(p, n)
+            relpath = os.path.dirname(os.path.relpath(image, input_dir))
 
 
             if n > 0:
             if n > 0:
-                left, right = os.path.splitext(filename)
-                filename = f"{left}-{n}{right}"
+                filename += f"-{n}"
 
 
             if not save_normally:
             if not save_normally:
-                os.makedirs(output_dir, exist_ok=True)
+                os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
                 if processed_image.mode == 'RGBA':
                 if processed_image.mode == 'RGBA':
                     processed_image = processed_image.convert("RGB")
                     processed_image = processed_image.convert("RGB")
-                processed_image.save(os.path.join(output_dir, filename))
+                save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
 
 
 
 
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
     override_settings = create_override_settings_dict(override_settings_texts)
 
 
     is_batch = mode == 5
     is_batch = mode == 5
@@ -180,24 +211,25 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
     p.scripts = modules.scripts.scripts_img2img
     p.scripts = modules.scripts.scripts_img2img
     p.script_args = args
     p.script_args = args
 
 
+    p.user = request.username
+
     if shared.cmd_opts.enable_console_prompts:
     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
 
 
     if mask:
     if mask:
         p.extra_generation_params["Mask blur"] = mask_blur
         p.extra_generation_params["Mask blur"] = mask_blur
 
 
-    if is_batch:
-        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+    with closing(p):
+        if is_batch:
+            assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
 
 
-        process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
-
-        processed = Processed(p, [], p.seed, "")
-    else:
-        processed = modules.scripts.scripts_img2img.run(p, *args)
-        if processed is None:
-            processed = process_images(p)
+            process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
 
 
-    p.close()
+            processed = Processed(p, [], p.seed, "")
+        else:
+            processed = modules.scripts.scripts_img2img.run(p, *args)
+            if processed is None:
+                processed = process_images(p)
 
 
     shared.total_tqdm.clear()
     shared.total_tqdm.clear()
 
 

+ 1 - 2
modules/interrogate.py

@@ -184,8 +184,7 @@ class InterrogateModels:
 
 
     def interrogate(self, pil_image):
     def interrogate(self, pil_image):
         res = ""
         res = ""
-        shared.state.begin()
-        shared.state.job = 'interrogate'
+        shared.state.begin(job="interrogate")
         try:
         try:
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                 lowvram.send_everything_to_cpu()
                 lowvram.send_everything_to_cpu()

+ 3 - 3
modules/launch_utils.py

@@ -142,15 +142,15 @@ def git_clone(url, dir, name, commithash=None):
         if commithash is None:
         if commithash is None:
             return
             return
 
 
-        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+        current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
         if current_hash == commithash:
         if current_hash == commithash:
             return
             return
 
 
         run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
         run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
         return
         return
 
 
-    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
 
 
     if commithash is not None:
     if commithash is not None:
         run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
         run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")

+ 31 - 8
modules/mac_specific.py

@@ -1,20 +1,43 @@
+import logging
+
 import torch
 import torch
 import platform
 import platform
 from modules.sd_hijack_utils import CondFunc
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 from packaging import version
 
 
+log = logging.getLogger(__name__)
+
 
 
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
 def check_for_mps() -> bool:
 def check_for_mps() -> bool:
-    if not getattr(torch, 'has_mps', False):
-        return False
+    if version.parse(torch.__version__) <= version.parse("2.0.1"):
+        if not getattr(torch, 'has_mps', False):
+            return False
+        try:
+            torch.zeros(1).to(torch.device("mps"))
+            return True
+        except Exception:
+            return False
+    else:
+        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
+has_mps = check_for_mps()
+
+
+def torch_mps_gc() -> None:
     try:
     try:
-        torch.zeros(1).to(torch.device("mps"))
-        return True
+        from modules.shared import state
+        if state.current_latent is not None:
+            log.debug("`current_latent` is set, skipping MPS garbage collection")
+            return
+        from torch.mps import empty_cache
+        empty_cache()
     except Exception:
     except Exception:
-        return False
-has_mps = check_for_mps()
+        log.warning("MPS garbage collection failed", exc_info=True)
 
 
 
 
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784

+ 27 - 4
modules/modelloader.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import os
 import shutil
 import shutil
 import importlib
 import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
 from modules.paths import script_path, models_path
 from modules.paths import script_path, models_path
 
 
 
 
+def load_file_from_url(
+    url: str,
+    *,
+    model_dir: str,
+    progress: bool = True,
+    file_name: str | None = None,
+) -> str:
+    """Download a file from `url` into `model_dir`, using the file present if possible.
+
+    Returns the path to the downloaded file.
+    """
+    os.makedirs(model_dir, exist_ok=True)
+    if not file_name:
+        parts = urlparse(url)
+        file_name = os.path.basename(parts.path)
+    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        from torch.hub import download_url_to_file
+        download_url_to_file(url, cached_file, progress=progress)
+    return cached_file
+
+
 def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
 def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
     """
     """
     A one-and done loader to try finding the desired models in specified directories.
     A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
 
 
         if model_url is not None and len(output) == 0:
         if model_url is not None and len(output) == 0:
             if download_name is not None:
             if download_name is not None:
-                from basicsr.utils.download_util import load_file_from_url
-                dl = load_file_from_url(model_url, places[0], True, download_name)
-                output.append(dl)
+                output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
             else:
             else:
                 output.append(model_url)
                 output.append(model_url)
 
 
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
 
 
 
 
 def friendly_name(file: str):
 def friendly_name(file: str):
-    if "http" in file:
+    if file.startswith("http"):
         file = urlparse(file).path
         file = urlparse(file).path
 
 
     file = os.path.basename(file)
     file = os.path.basename(file)

+ 0 - 14
modules/paths.py

@@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
         else:
         else:
             sys.path.append(d)
             sys.path.append(d)
         paths[what] = d
         paths[what] = d
-
-
-class Prioritize:
-    def __init__(self, name):
-        self.name = name
-        self.path = None
-
-    def __enter__(self):
-        self.path = sys.path.copy()
-        sys.path = [paths[self.name]] + sys.path
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        sys.path = self.path
-        self.path = None

+ 4 - 3
modules/postprocessing.py

@@ -9,8 +9,7 @@ from modules.shared import opts
 def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
 def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
     devices.torch_gc()
     devices.torch_gc()
 
 
-    shared.state.begin()
-    shared.state.job = 'extras'
+    shared.state.begin(job="extras")
 
 
     image_data = []
     image_data = []
     image_names = []
     image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
     for image, name in zip(image_data, image_names):
     for image, name in zip(image_data, image_names):
         shared.state.textinfo = name
         shared.state.textinfo = name
 
 
-        existing_pnginfo = image.info or {}
+        parameters, existing_pnginfo = images.read_info_from_image(image)
+        if parameters:
+            existing_pnginfo["parameters"] = parameters
 
 
         pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
         pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
 
 

+ 15 - 8
modules/processing.py

@@ -184,6 +184,8 @@ class StableDiffusionProcessing:
         self.uc = None
         self.uc = None
         self.c = None
         self.c = None
 
 
+        self.user = None
+
     @property
     @property
     def sd_model(self):
     def sd_model(self):
         return shared.sd_model
         return shared.sd_model
@@ -549,7 +551,7 @@ def program_version():
     return res
     return res
 
 
 
 
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
     index = position_in_batch + iteration * p.batch_size
     index = position_in_batch + iteration * p.batch_size
 
 
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -573,7 +575,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
         "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
-        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
         "Denoising strength": getattr(p, 'denoising_strength', None),
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         **p.extra_generation_params,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
         "Version": program_version() if opts.add_version_to_infotext else None,
+        "User": p.user if opts.add_user_name_to_info else None,
     }
     }
 
 
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
 
+    prompt_text = p.prompt if use_main_prompt else all_prompts[index]
     negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
     negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
 
 
-    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
 
 
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
 def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -602,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
 
     try:
     try:
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
-        if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
             p.override_settings.pop('sd_model_checkpoint', None)
             p.override_settings.pop('sd_model_checkpoint', None)
             sd_models.reload_model_weights()
             sd_models.reload_model_weights()
 
 
@@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     else:
     else:
         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
 
-    def infotext(iteration=0, position_in_batch=0):
-        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
 
 
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
         model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             grid = images.image_grid(output_images, p.batch_size)
             grid = images.image_grid(output_images, p.batch_size)
 
 
             if opts.return_grid:
             if opts.return_grid:
-                text = infotext()
+                text = infotext(use_main_prompt=True)
                 infotexts.insert(0, text)
                 infotexts.insert(0, text)
                 if opts.enable_pnginfo:
                 if opts.enable_pnginfo:
                     grid.info["parameters"] = text
                     grid.info["parameters"] = text
@@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 index_of_first_image = 1
                 index_of_first_image = 1
 
 
             if opts.grid_save:
             if opts.grid_save:
-                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
 
     if not p.disable_extra_networks and p.extra_network_data:
     if not p.disable_extra_networks and p.extra_network_data:
         extra_networks.deactivate(p, p.extra_network_data)
         extra_networks.deactivate(p, p.extra_network_data)
@@ -1074,6 +1078,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
 
 
+        if self.scripts is not None:
+            self.scripts.before_hr(self)
+
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

+ 15 - 18
modules/realesrgan_model.py

@@ -2,7 +2,6 @@ import os
 
 
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 from realesrgan import RealESRGANer
 from realesrgan import RealESRGANer
 
 
 from modules.upscaler import Upscaler, UpscalerData
 from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
         if not self.enable:
         if not self.enable:
             return img
             return img
 
 
-        info = self.load_model(path)
-        if not os.path.exists(info.local_data_path):
-            print(f"Unable to load RealESRGAN model: {info.name}")
+        try:
+            info = self.load_model(path)
+        except Exception:
+            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
             return img
             return img
 
 
         upsampler = RealESRGANer(
         upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
         return image
         return image
 
 
     def load_model(self, path):
     def load_model(self, path):
-        try:
-            info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
-            if info is None:
-                print(f"Unable to find model info: {path}")
-                return None
-
-            if info.local_data_path.startswith("http"):
-                info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
-            return info
-        except Exception:
-            errors.report("Error making Real-ESRGAN models list", exc_info=True)
-        return None
+        for scaler in self.scalers:
+            if scaler.data_path == path:
+                if scaler.local_data_path.startswith("http"):
+                    scaler.local_data_path = modelloader.load_file_from_url(
+                        scaler.data_path,
+                        model_dir=self.model_download_path,
+                    )
+                if not os.path.exists(scaler.local_data_path):
+                    raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+                return scaler
+        raise ValueError(f"Unable to find model info: {path}")
 
 
     def load_models(self, _):
     def load_models(self, _):
         return get_realesrgan_models(self)
         return get_realesrgan_models(self)

+ 39 - 1
modules/scripts.py

@@ -1,6 +1,7 @@
 import os
 import os
 import re
 import re
 import sys
 import sys
+import inspect
 from collections import namedtuple
 from collections import namedtuple
 
 
 import gradio as gr
 import gradio as gr
@@ -116,6 +117,21 @@ class Script:
 
 
         pass
         pass
 
 
+    def after_extra_networks_activate(self, p, *args, **kwargs):
+        """
+        Calledafter extra networks activation, before conds calculation
+        allow modification of the network after extra networks activation been applied
+        won't be call if p.disable_extra_networks
+
+        **kwargs will have those items:
+          - batch_number - index of current batch, from 0 to number of batches-1
+          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+          - seeds - list of seeds for current batch
+          - subseeds - list of subseeds for current batch
+          - extra_network_data - list of ExtraNetworkParams for current stage
+        """
+        pass
+
     def process_batch(self, p, *args, **kwargs):
     def process_batch(self, p, *args, **kwargs):
         """
         """
         Same as process(), but called for every batch.
         Same as process(), but called for every batch.
@@ -186,6 +202,11 @@ class Script:
 
 
         return f'script_{tabname}{title}_{item_id}'
         return f'script_{tabname}{title}_{item_id}'
 
 
+    def before_hr(self, p, *args):
+        """
+        This function is called before hires fix start.
+        """
+        pass
 
 
 current_basedir = paths.script_path
 current_basedir = paths.script_path
 
 
@@ -249,7 +270,7 @@ def load_scripts():
 
 
     def register_scripts_from_module(module):
     def register_scripts_from_module(module):
         for script_class in module.__dict__.values():
         for script_class in module.__dict__.values():
-            if type(script_class) != type:
+            if not inspect.isclass(script_class):
                 continue
                 continue
 
 
             if issubclass(script_class, Script):
             if issubclass(script_class, Script):
@@ -483,6 +504,14 @@ class ScriptRunner:
             except Exception:
             except Exception:
                 errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
                 errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
 
 
+    def after_extra_networks_activate(self, p, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.after_extra_networks_activate(p, *script_args, **kwargs)
+            except Exception:
+                errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
+
     def process_batch(self, p, **kwargs):
     def process_batch(self, p, **kwargs):
         for script in self.alwayson_scripts:
         for script in self.alwayson_scripts:
             try:
             try:
@@ -548,6 +577,15 @@ class ScriptRunner:
                     self.scripts[si].args_to = args_to
                     self.scripts[si].args_to = args_to
 
 
 
 
+    def before_hr(self, p):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.before_hr(p, *script_args)
+            except Exception:
+                errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
 scripts_txt2img: ScriptRunner = None
 scripts_txt2img: ScriptRunner = None
 scripts_img2img: ScriptRunner = None
 scripts_img2img: ScriptRunner = None
 scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
 scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None

+ 12 - 7
modules/sd_models.py

@@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
 
 
 checkpoints_list = {}
 checkpoints_list = {}
-checkpoint_alisases = {}
+checkpoint_aliases = {}
+checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()
 checkpoints_loaded = collections.OrderedDict()
 
 
 
 
@@ -66,7 +67,7 @@ class CheckpointInfo:
     def register(self):
     def register(self):
         checkpoints_list[self.title] = self
         checkpoints_list[self.title] = self
         for id in self.ids:
         for id in self.ids:
-            checkpoint_alisases[id] = self
+            checkpoint_aliases[id] = self
 
 
     def calculate_shorthash(self):
     def calculate_shorthash(self):
         self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
         self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
@@ -112,7 +113,7 @@ def checkpoint_tiles():
 
 
 def list_models():
 def list_models():
     checkpoints_list.clear()
     checkpoints_list.clear()
-    checkpoint_alisases.clear()
+    checkpoint_aliases.clear()
 
 
     cmd_ckpt = shared.cmd_opts.ckpt
     cmd_ckpt = shared.cmd_opts.ckpt
     if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
     if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
@@ -136,7 +137,7 @@ def list_models():
 
 
 
 
 def get_closet_checkpoint_match(search_string):
 def get_closet_checkpoint_match(search_string):
-    checkpoint_info = checkpoint_alisases.get(search_string, None)
+    checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
     if checkpoint_info is not None:
         return checkpoint_info
         return checkpoint_info
 
 
@@ -166,7 +167,7 @@ def select_checkpoint():
     """Raises `FileNotFoundError` if no checkpoints are found."""
     """Raises `FileNotFoundError` if no checkpoints are found."""
     model_checkpoint = shared.opts.sd_model_checkpoint
     model_checkpoint = shared.opts.sd_model_checkpoint
 
 
-    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
+    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
     if checkpoint_info is not None:
     if checkpoint_info is not None:
         return checkpoint_info
         return checkpoint_info
 
 
@@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
     _, extension = os.path.splitext(checkpoint_file)
     _, extension = os.path.splitext(checkpoint_file)
     if extension.lower() == ".safetensors":
     if extension.lower() == ".safetensors":
         device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
         device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
-        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+        if not shared.opts.disable_mmap_load_safetensors:
+            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+        else:
+            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
     else:
     else:
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
         pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
 
 
@@ -585,7 +591,6 @@ def unload_model_weights(sd_model=None, info=None):
         sd_model = None
         sd_model = None
         gc.collect()
         gc.collect()
         devices.torch_gc()
         devices.torch_gc()
-        torch.cuda.empty_cache()
 
 
     print(f"Unloaded weights {timer.summary()}.")
     print(f"Unloaded weights {timer.summary()}.")
 
 

+ 29 - 6
modules/shared.py

@@ -1,9 +1,11 @@
 import datetime
 import datetime
 import json
 import json
 import os
 import os
+import re
 import sys
 import sys
 import threading
 import threading
 import time
 import time
+import logging
 
 
 import gradio as gr
 import gradio as gr
 import torch
 import torch
@@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
 from ldm.models.diffusion.ddpm import LatentDiffusion
 from ldm.models.diffusion.ddpm import LatentDiffusion
 from typing import Optional
 from typing import Optional
 
 
+log = logging.getLogger(__name__)
+
 demo = None
 demo = None
 
 
 parser = cmd_args.parser
 parser = cmd_args.parser
@@ -144,12 +148,15 @@ class State:
     def request_restart(self) -> None:
     def request_restart(self) -> None:
         self.interrupt()
         self.interrupt()
         self.server_command = "restart"
         self.server_command = "restart"
+        log.info("Received restart request")
 
 
     def skip(self):
     def skip(self):
         self.skipped = True
         self.skipped = True
+        log.info("Received skip request")
 
 
     def interrupt(self):
     def interrupt(self):
         self.interrupted = True
         self.interrupted = True
+        log.info("Received interrupt request")
 
 
     def nextjob(self):
     def nextjob(self):
         if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
         if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -173,7 +180,7 @@ class State:
 
 
         return obj
         return obj
 
 
-    def begin(self):
+    def begin(self, job: str = "(unknown)"):
         self.sampling_step = 0
         self.sampling_step = 0
         self.job_count = -1
         self.job_count = -1
         self.processing_has_refined_job_count = False
         self.processing_has_refined_job_count = False
@@ -187,10 +194,13 @@ class State:
         self.interrupted = False
         self.interrupted = False
         self.textinfo = None
         self.textinfo = None
         self.time_start = time.time()
         self.time_start = time.time()
-
+        self.job = job
         devices.torch_gc()
         devices.torch_gc()
+        log.info("Starting job %s", job)
 
 
     def end(self):
     def end(self):
+        duration = time.time() - self.time_start
+        log.info("Ending job %s (%.2f seconds)", self.job, duration)
         self.job = ""
         self.job = ""
         self.job_count = 0
         self.job_count = 0
 
 
@@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
     "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
     "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
     "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
     "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
     "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
     "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+    "font": OptionInfo("", "Font for image grids that have text"),
+    "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+    "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+    "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
 
 
     "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
     "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
     "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
     "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@@ -376,6 +390,7 @@ options_templates.update(options_section(('system', "System"), {
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+    "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
 }))
 }))
 
 
 options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('training', "Training"), {
@@ -470,7 +485,6 @@ options_templates.update(options_section(('ui', "User interface"), {
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
-    "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
     "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
     "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
@@ -481,6 +495,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
     "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+    "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
     "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
     "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
@@ -493,6 +508,7 @@ options_templates.update(options_section(('ui', "User interface"), {
 options_templates.update(options_section(('infotext', "Infotext"), {
 options_templates.update(options_section(('infotext', "Infotext"), {
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+    "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
     "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
     "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
     "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
     "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
     "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
     "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
@@ -817,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
 mem_mon.start()
 mem_mon.start()
 
 
 
 
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
+    return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
 def listfiles(dirname):
 def listfiles(dirname):
-    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
+    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
     return [file for file in filenames if os.path.isfile(file)]
     return [file for file in filenames if os.path.isfile(file)]
 
 
 
 
@@ -843,8 +863,11 @@ def walk_files(path, allowed_extensions=None):
     if allowed_extensions is not None:
     if allowed_extensions is not None:
         allowed_extensions = set(allowed_extensions)
         allowed_extensions = set(allowed_extensions)
 
 
-    for root, _, files in os.walk(path, followlinks=True):
-        for filename in files:
+    items = list(os.walk(path, followlinks=True))
+    items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+    for root, _, files in items:
+        for filename in sorted(files, key=natural_sort_key):
             if allowed_extensions is not None:
             if allowed_extensions is not None:
                 _, ext = os.path.splitext(filename)
                 _, ext = os.path.splitext(filename)
                 if ext not in allowed_extensions:
                 if ext not in allowed_extensions:

+ 44 - 4
modules/textual_inversion/logging.py

@@ -2,11 +2,51 @@ import datetime
 import json
 import json
 import os
 import os
 
 
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+    "batch_size",
+    "clip_grad_mode",
+    "clip_grad_value",
+    "create_image_every",
+    "data_root",
+    "gradient_step",
+    "initial_step",
+    "latent_sampling_method",
+    "learn_rate",
+    "log_directory",
+    "model_hash",
+    "model_name",
+    "num_of_dataset_images",
+    "steps",
+    "template_file",
+    "training_height",
+    "training_width",
+}
+saved_params_ti = {
+    "embedding_name",
+    "num_vectors_per_token",
+    "save_embedding_every",
+    "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+    "activation_func",
+    "add_layer_norm",
+    "hypernetwork_name",
+    "layer_structure",
+    "save_hypernetwork_every",
+    "use_dropout",
+    "weight_init",
+}
 saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
 saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+    "preview_cfg_scale",
+    "preview_height",
+    "preview_negative_prompt",
+    "preview_prompt",
+    "preview_sampler_index",
+    "preview_seed",
+    "preview_steps",
+    "preview_width",
+}
 
 
 
 
 def save_settings_to_file(log_directory, all_params):
 def save_settings_to_file(log_directory, all_params):

+ 1 - 1
modules/textual_inversion/preprocess.py

@@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
 from modules.textual_inversion import autocrop
 from modules.textual_inversion import autocrop
 
 
 
 
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
     try:
     try:
         if process_caption:
         if process_caption:
             shared.interrogator.load()
             shared.interrogator.load()

+ 4 - 2
modules/textual_inversion/textual_inversion.py

@@ -1,5 +1,6 @@
 import os
 import os
 from collections import namedtuple
 from collections import namedtuple
+from contextlib import closing
 
 
 import torch
 import torch
 import tqdm
 import tqdm
@@ -584,8 +585,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
 
 
                     preview_text = p.prompt
                     preview_text = p.prompt
 
 
-                    processed = processing.process_images(p)
-                    image = processed.images[0] if len(processed.images) > 0 else None
+                    with closing(p):
+                        processed = processing.process_images(p)
+                        image = processed.images[0] if len(processed.images) > 0 else None
 
 
                     if unload:
                     if unload:
                         shared.sd_model.first_stage_model.to(devices.cpu)
                         shared.sd_model.first_stage_model.to(devices.cpu)

+ 10 - 7
modules/txt2img.py

@@ -1,13 +1,15 @@
+from contextlib import closing
+
 import modules.scripts
 import modules.scripts
 from modules import sd_samplers, processing
 from modules import sd_samplers, processing
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.shared import opts, cmd_opts
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 import modules.shared as shared
 from modules.ui import plaintext_to_html
 from modules.ui import plaintext_to_html
+import gradio as gr
 
 
 
 
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
     override_settings = create_override_settings_dict(override_settings_texts)
 
 
     p = processing.StableDiffusionProcessingTxt2Img(
     p = processing.StableDiffusionProcessingTxt2Img(
@@ -48,15 +50,16 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
     p.scripts = modules.scripts.scripts_txt2img
     p.scripts = modules.scripts.scripts_txt2img
     p.script_args = args
     p.script_args = args
 
 
+    p.user = request.username
+
     if cmd_opts.enable_console_prompts:
     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
 
 
-    processed = modules.scripts.scripts_txt2img.run(p, *args)
-
-    if processed is None:
-        processed = processing.process_images(p)
+    with closing(p):
+        processed = modules.scripts.scripts_txt2img.run(p, *args)
 
 
-    p.close()
+        if processed is None:
+            processed = processing.process_images(p)
 
 
     shared.total_tqdm.clear()
     shared.total_tqdm.clear()
 
 

+ 10 - 3
modules/ui.py

@@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
             img = Image.open(image)
             img = Image.open(image)
             filename = os.path.basename(image)
             filename = os.path.basename(image)
             left, _ = os.path.splitext(filename)
             left, _ = os.path.splitext(filename)
-            print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+            print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
 
 
         return [gr.update(), None]
         return [gr.update(), None]
 
 
@@ -733,6 +733,10 @@ def create_ui():
                         img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                         img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                         img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
                         img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
                         img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
                         img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+                        with gr.Accordion("PNG info", open=False):
+                            img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+                            img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+                            img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
 
 
                     img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
                     img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
 
 
@@ -773,7 +777,7 @@ def create_ui():
                                 selected_scale_tab = gr.State(value=0)
                                 selected_scale_tab = gr.State(value=0)
 
 
                                 with gr.Tabs():
                                 with gr.Tabs():
-                                    with gr.Tab(label="Resize to") as tab_scale_to:
+                                    with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
                                         with FormRow():
                                         with FormRow():
                                             with gr.Column(elem_id="img2img_column_size", scale=4):
                                             with gr.Column(elem_id="img2img_column_size", scale=4):
                                                 width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
                                                 width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -782,7 +786,7 @@ def create_ui():
                                                 res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
                                                 res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
                                                 detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
                                                 detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
 
 
-                                    with gr.Tab(label="Resize by") as tab_scale_by:
+                                    with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
                                         scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
                                         scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
 
 
                                         with FormRow():
                                         with FormRow():
@@ -934,6 +938,9 @@ def create_ui():
                     img2img_batch_output_dir,
                     img2img_batch_output_dir,
                     img2img_batch_inpaint_mask_dir,
                     img2img_batch_inpaint_mask_dir,
                     override_settings,
                     override_settings,
+                    img2img_batch_use_png_info,
+                    img2img_batch_png_info_props,
+                    img2img_batch_png_info_dir,
                 ] + custom_inputs,
                 ] + custom_inputs,
                 outputs=[
                 outputs=[
                     img2img_gallery,
                     img2img_gallery,

+ 23 - 6
modules/ui_extensions.py

@@ -138,7 +138,10 @@ def extension_table():
     <table id="extensions">
     <table id="extensions">
         <thead>
         <thead>
             <tr>
             <tr>
-                <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+                <th>
+                    <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+                    <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+                </th>
                 <th>URL</th>
                 <th>URL</th>
                 <th>Branch</th>
                 <th>Branch</th>
                 <th>Version</th>
                 <th>Version</th>
@@ -170,7 +173,7 @@ def extension_table():
 
 
         code += f"""
         code += f"""
             <tr>
             <tr>
-                <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+                <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
                 <td>{remote}</td>
                 <td>{remote}</td>
                 <td>{ext.branch}</td>
                 <td>{ext.branch}</td>
                 <td>{version_link}</td>
                 <td>{version_link}</td>
@@ -421,9 +424,19 @@ sort_ordering = [
     (False, lambda x: x.get('name', 'z')),
     (False, lambda x: x.get('name', 'z')),
     (True, lambda x: x.get('name', 'z')),
     (True, lambda x: x.get('name', 'z')),
     (False, lambda x: 'z'),
     (False, lambda x: 'z'),
+    (True, lambda x: x.get('commit_time', '')),
+    (True, lambda x: x.get('created_at', '')),
+    (True, lambda x: x.get('stars', 0)),
 ]
 ]
 
 
 
 
+def get_date(info: dict, key):
+    try:
+        return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+    except (ValueError, TypeError):
+        return ''
+
+
 def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
 def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
     extlist = available_extensions["extensions"]
     extlist = available_extensions["extensions"]
     installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
     installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
 
 
     for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
     for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
         name = ext.get("name", "noname")
         name = ext.get("name", "noname")
+        stars = int(ext.get("stars", 0))
         added = ext.get('added', 'unknown')
         added = ext.get('added', 'unknown')
+        update_time = get_date(ext, 'commit_time')
+        create_time = get_date(ext, 'created_at')
         url = ext.get("url", None)
         url = ext.get("url", None)
         description = ext.get("description", "")
         description = ext.get("description", "")
         extension_tags = ext.get("tags", [])
         extension_tags = ext.get("tags", [])
@@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
         code += f"""
         code += f"""
             <tr>
             <tr>
                 <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
                 <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
-                <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
+                <td>{html.escape(description)}<p class="info">
+                <span class="date_added">Update: {html.escape(update_time)}  Added: {html.escape(added)}  Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
                 <td>{install_code}</td>
                 <td>{install_code}</td>
             </tr>
             </tr>
 
 
@@ -559,7 +576,7 @@ def create_ui():
 
 
                 with gr.Row():
                 with gr.Row():
                     hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
                     hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
-                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+                    sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
 
 
                 with gr.Row():
                 with gr.Row():
                     search_extensions_text = gr.Text(label="Search").style(container=False)
                     search_extensions_text = gr.Text(label="Search").style(container=False)
@@ -568,9 +585,9 @@ def create_ui():
                 available_extensions_table = gr.HTML()
                 available_extensions_table = gr.HTML()
 
 
                 refresh_available_extensions_button.click(
                 refresh_available_extensions_button.click(
-                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
                     inputs=[available_extensions_index, hide_tags, sort_column],
                     inputs=[available_extensions_index, hide_tags, sort_column],
-                    outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+                    outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
                 )
                 )
 
 
                 install_extension_button.click(
                 install_extension_button.click(

+ 4 - 4
modules/ui_extra_networks.py

@@ -30,8 +30,8 @@ def fetch_file(filename: str = ""):
         raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
         raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
 
 
     ext = os.path.splitext(filename)[1].lower()
     ext = os.path.splitext(filename)[1].lower()
-    if ext not in (".png", ".jpg", ".jpeg", ".webp"):
-        raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+    if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
+        raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
 
 
     # would profit from returning 304
     # would profit from returning 304
     return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
     return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -90,8 +90,8 @@ class ExtraNetworksPage:
 
 
         subdirs = {}
         subdirs = {}
         for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
         for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
-            for root, dirs, _ in os.walk(parentdir, followlinks=True):
-                for dirname in dirs:
+            for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
+                for dirname in sorted(dirs, key=shared.natural_sort_key):
                     x = os.path.join(root, dirname)
                     x = os.path.join(root, dirname)
 
 
                     if not os.path.isdir(x):
                     if not os.path.isdir(x):

+ 14 - 7
modules/ui_settings.py

@@ -260,13 +260,20 @@ class UiSettings:
             component = self.component_dict[k]
             component = self.component_dict[k]
             info = opts.data_labels[k]
             info = opts.data_labels[k]
 
 
-            change_handler = component.release if hasattr(component, 'release') else component.change
-            change_handler(
-                fn=lambda value, k=k: self.run_settings_single(value, key=k),
-                inputs=[component],
-                outputs=[component, self.text_settings],
-                show_progress=info.refresh is not None,
-            )
+            if isinstance(component, gr.Textbox):
+                methods = [component.submit, component.blur]
+            elif hasattr(component, 'release'):
+                methods = [component.release]
+            else:
+                methods = [component.change]
+
+            for method in methods:
+                method(
+                    fn=lambda value, k=k: self.run_settings_single(value, key=k),
+                    inputs=[component],
+                    outputs=[component, self.text_settings],
+                    show_progress=info.refresh is not None,
+                )
 
 
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
         button_set_checkpoint.click(

+ 15 - 2
style.css

@@ -704,11 +704,24 @@ table.popup-table .link{
     margin: 0;
     margin: 0;
 }
 }
 
 
-#available_extensions .date_added{
-    opacity: 0.85;
+#available_extensions .info{
+    margin: 0.5em 0;
+    display: flex;
+    margin-top: auto;
+    opacity: 0.80;
     font-size: 90%;
     font-size: 90%;
 }
 }
 
 
+#available_extensions .date_added{
+    margin-right: auto;
+    display: inline-block;
+}
+
+#available_extensions .star_count{
+    margin-left: auto;
+    display: inline-block;
+}
+
 /* replace original footer with ours */
 /* replace original footer with ours */
 
 
 footer {
 footer {

+ 19 - 15
webui.py

@@ -11,13 +11,24 @@ import json
 from threading import Thread
 from threading import Thread
 from typing import Iterable
 from typing import Iterable
 
 
-from fastapi import FastAPI, Response
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from packaging import version
 from packaging import version
 
 
 import logging
 import logging
 
 
+# We can't use cmd_opts for this because it will not have been initialized at this point.
+log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
+if log_level:
+    log_level = getattr(logging, log_level.upper(), None) or logging.INFO
+    logging.basicConfig(
+        level=log_level,
+        format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S',
+    )
+
+logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
 
 from modules import paths, timer, import_hook, errors, devices  # noqa: F401
 from modules import paths, timer, import_hook, errors, devices  # noqa: F401
@@ -32,7 +43,7 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
 
 
 startup_timer.record("import torch")
 startup_timer.record("import torch")
 
 
-import gradio
+import gradio  # noqa: F401
 startup_timer.record("import gradio")
 startup_timer.record("import gradio")
 
 
 import ldm.modules.encoders.modules  # noqa: F401
 import ldm.modules.encoders.modules  # noqa: F401
@@ -359,12 +370,11 @@ def api_only():
     modules.script_callbacks.app_started_callback(None, app)
     modules.script_callbacks.app_started_callback(None, app)
 
 
     print(f"Startup time: {startup_timer.summary()}.")
     print(f"Startup time: {startup_timer.summary()}.")
-    api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-
-
-def stop_route(request):
-    shared.state.server_command = "stop"
-    return Response("Stopping.")
+    api.launch(
+        server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
+        port=cmd_opts.port if cmd_opts.port else 7861,
+        root_path = f"/{cmd_opts.subpath}"
+    )
 
 
 
 
 def webui():
 def webui():
@@ -403,9 +413,8 @@ def webui():
                 "docs_url": "/docs",
                 "docs_url": "/docs",
                 "redoc_url": "/redoc",
                 "redoc_url": "/redoc",
             },
             },
+            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
         )
         )
-        if cmd_opts.add_stop_route:
-            app.add_route("/_stop", stop_route, methods=["POST"])
 
 
         # after initial launch, disable --autolaunch for subsequent restarts
         # after initial launch, disable --autolaunch for subsequent restarts
         cmd_opts.autolaunch = False
         cmd_opts.autolaunch = False
@@ -436,11 +445,6 @@ def webui():
         timer.startup_record = startup_timer.dump()
         timer.startup_record = startup_timer.dump()
         print(f"Startup time: {startup_timer.summary()}.")
         print(f"Startup time: {startup_timer.summary()}.")
 
 
-        if cmd_opts.subpath:
-            redirector = FastAPI()
-            redirector.get("/")
-            gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
-
         try:
         try:
             while True:
             while True:
                 server_command = shared.state.wait_for_server_command(timeout=5)
                 server_command = shared.state.wait_for_server_command(timeout=5)

+ 11 - 5
webui.sh

@@ -4,26 +4,28 @@
 # change the variables in webui-user.sh instead #
 # change the variables in webui-user.sh instead #
 #################################################
 #################################################
 
 
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
 # If run from macOS, load defaults from webui-macos-env.sh
 # If run from macOS, load defaults from webui-macos-env.sh
 if [[ "$OSTYPE" == "darwin"* ]]; then
 if [[ "$OSTYPE" == "darwin"* ]]; then
-    if [[ -f webui-macos-env.sh ]]
+    if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
         then
         then
-        source ./webui-macos-env.sh
+        source "$SCRIPT_DIR"/webui-macos-env.sh
     fi
     fi
 fi
 fi
 
 
 # Read variables from webui-user.sh
 # Read variables from webui-user.sh
 # shellcheck source=/dev/null
 # shellcheck source=/dev/null
-if [[ -f webui-user.sh ]]
+if [[ -f "$SCRIPT_DIR"/webui-user.sh ]]
 then
 then
-    source ./webui-user.sh
+    source "$SCRIPT_DIR"/webui-user.sh
 fi
 fi
 
 
 # Set defaults
 # Set defaults
 # Install directory without trailing slash
 # Install directory without trailing slash
 if [[ -z "${install_dir}" ]]
 if [[ -z "${install_dir}" ]]
 then
 then
-    install_dir="$(pwd)"
+    install_dir="$SCRIPT_DIR"
 fi
 fi
 
 
 # Name of the subdirectory (defaults to stable-diffusion-webui)
 # Name of the subdirectory (defaults to stable-diffusion-webui)
@@ -131,6 +133,10 @@ case "$gpu_info" in
     ;;
     ;;
     *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
     *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
     ;;
     ;;
+    *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
+        export TORCH_COMMAND="pip install --pre torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 --index-url https://download.pytorch.org/whl/nightly/rocm5.5"
+        # Navi 3 needs at least 5.5 which is only on the nightly chain
+    ;;
     *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
     *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
         printf "\n%s\n" "${delimiter}"
         printf "\n%s\n" "${delimiter}"
         printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
         printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"