Merge branch 'dev' into improve-frontend-responsiveness

AUTOMATIC1111 · 2 years ago · commit 04b4508a66
100 changed files with 2310 additions and 1075 deletions
1. .github/ISSUE_TEMPLATE/bug_report.yml (+21 -0)
2. .github/workflows/on_pull_request.yaml (+25 -18)
3. .github/workflows/run_tests.yaml (+6 -0)
4. .gitignore (+2 -1)
5. CHANGELOG.md (+120 -0)
6. README.md (+11 -3)
7. environment-wsl2.yaml (+5 -5)
8. extensions-builtin/LDSR/ldsr_model_arch.py (+6 -7)
9. extensions-builtin/LDSR/scripts/ldsr_model.py (+15 -8)
10. extensions-builtin/LDSR/sd_hijack_autoencoder.py (+17 -11)
11. extensions-builtin/LDSR/sd_hijack_ddpm_v1.py (+30 -36)
12. extensions-builtin/Lora/extra_networks_lora.py (+2 -1)
13. extensions-builtin/Lora/lora.py (+121 -24)
14. extensions-builtin/Lora/scripts/lora_script.py (+28 -2)
15. extensions-builtin/Lora/ui_extra_networks_lora.py (+7 -1)
16. extensions-builtin/ScuNET/scripts/scunet_model.py (+79 -16)
17. extensions-builtin/ScuNET/scunet_model_arch.py (+7 -4)
18. extensions-builtin/SwinIR/scripts/swinir_model.py (+3 -4)
19. extensions-builtin/SwinIR/swinir_model_arch.py (+3 -3)
20. extensions-builtin/SwinIR/swinir_model_arch_v2.py (+29 -29)
21. extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js (+30 -91)
22. html/extra-networks-card.html (+1 -1)
23. html/licenses.html (+26 -0)
24. javascript/aspectRatioOverlay.js (+14 -19)
25. javascript/contextMenus.js (+6 -17)
26. javascript/edit-attention.js (+35 -11)
27. javascript/extensions.js (+28 -6)
28. javascript/extraNetworks.js (+48 -22)
29. javascript/generationParams.js (+4 -4)
30. javascript/hints.js (+46 -26)
31. javascript/hires_fix.js (+6 -10)
32. javascript/imageMaskFix.js (+6 -7)
33. javascript/imageParams.js (+0 -1)
34. javascript/imageviewer.js (+9 -10)
35. javascript/imageviewerGamepad.js (+57 -0)
36. javascript/localization.js (+44 -32)
37. javascript/notification.js (+3 -3)
38. javascript/progressbar.js (+5 -6)
39. javascript/ui.js (+94 -14)
40. javascript/ui_settings_hints.js (+62 -0)
41. launch.py (+58 -58)
42. modules/Roboto-Regular.ttf (BIN)
43. modules/api/api.py (+102 -88)
44. modules/api/models.py (+22 -4)
45. modules/call_queue.py (+4 -2)
46. modules/cmd_args.py (+4 -1)
47. modules/codeformer/codeformer_arch.py (+11 -13)
48. modules/codeformer/vqgan_arch.py (+21 -23)
49. modules/codeformer_model.py (+2 -4)
50. modules/config_states.py (+202 -0)
51. modules/deepbooru.py (+1 -2)
52. modules/devices.py (+7 -3)
53. modules/esrgan_model.py (+10 -11)
54. modules/esrgan_model_arch.py (+12 -11)
55. modules/extensions.py (+39 -12)
56. modules/extra_networks.py (+1 -1)
57. modules/extra_networks_hypernet.py (+4 -3)
58. modules/extras.py (+51 -5)
59. modules/generation_parameters_copypaste.py (+19 -9)
60. modules/gfpgan_model.py (+1 -1)
61. modules/hashes.py (+2 -2)
62. modules/hypernetworks/hypernetwork.py (+14 -15)
63. modules/hypernetworks/ui.py (+2 -4)
64. modules/images.py (+88 -51)
65. modules/img2img.py (+16 -10)
66. modules/interrogate.py (+7 -8)
67. modules/localization.py (+2 -2)
68. modules/mac_specific.py (+8 -4)
69. modules/masking.py (+1 -1)
70. modules/modelloader.py (+21 -43)
71. modules/models/diffusion/ddpm_edit.py (+26 -30)
72. modules/models/diffusion/uni_pc/__init__.py (+1 -1)
73. modules/models/diffusion/uni_pc/sampler.py (+2 -1)
74. modules/models/diffusion/uni_pc/uni_pc.py (+46 -40)
75. modules/ngrok.py (+14 -2)
76. modules/paths.py (+3 -3)
77. modules/paths_internal.py (+11 -2)
78. modules/postprocessing.py (+7 -2)
79. modules/processing.py (+94 -16)
80. modules/progress.py (+32 -2)
81. modules/prompt_parser.py (+15 -12)
82. modules/realesrgan_model.py (+17 -7)
83. modules/safe.py (+9 -8)
84. modules/script_callbacks.py (+53 -7)
85. modules/script_loading.py (+0 -1)
86. modules/scripts.py (+36 -8)
87. modules/scripts_auto_postprocessing.py (+1 -1)
88. modules/scripts_postprocessing.py (+4 -4)
89. modules/sd_disable_initialization.py (+1 -1)
90. modules/sd_hijack.py (+13 -10)
91. modules/sd_hijack_clip.py (+1 -1)
92. modules/sd_hijack_clip_old.py (+2 -1)
93. modules/sd_hijack_inpainting.py (+2 -8)
94. modules/sd_hijack_ip2p.py (+2 -5)
95. modules/sd_hijack_optimizations.py (+26 -24)
96. modules/sd_hijack_unet.py (+1 -1)
97. modules/sd_hijack_xlmr.py (+0 -2)
98. modules/sd_models.py (+89 -27)
99. modules/sd_models_config.py (+1 -2)
100. modules/sd_samplers.py (+8 -2)

+ 21 - 0
.github/ISSUE_TEMPLATE/bug_report.yml

@@ -47,6 +47,15 @@ body:
       description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
     validations:
       required: true
+  - type: dropdown
+    id: py-version
+    attributes:
+      label: What Python version are you running on ?
+      multiple: false
+      options:
+        - Python 3.10.x
+        - Python 3.11.x (above, not supported yet)
+        - Python 3.9.x (below, not recommended)
   - type: dropdown
     id: platforms
     attributes:
@@ -59,6 +68,18 @@ body:
         - iOS
         - Android
         - Other/Cloud
+  - type: dropdown
+    id: device
+    attributes:
+        label: What device are you running WebUI on?
+        multiple: true
+        options:
+        - Nvidia GPUs (RTX 20 above)
+        - Nvidia GPUs (GTX 16 below)
+        - AMD GPUs (RX 6000 above)
+        - AMD GPUs (RX 5000 below)
+        - CPU
+        - Other GPUs
   - type: dropdown
     id: browsers
     attributes:

+ 25 - 18
.github/workflows/on_pull_request.yaml

@@ -18,22 +18,29 @@ jobs:
     steps:
       - name: Checkout Code
         uses: actions/checkout@v3
-      - name: Set up Python 3.10
-        uses: actions/setup-python@v4
+      - uses: actions/setup-python@v4
         with:
-          python-version: 3.10.6
-          cache: pip
-          cache-dependency-path: |
-            **/requirements*txt
-      - name: Install PyLint
-        run: | 
-          python -m pip install --upgrade pip
-          pip install pylint
-      # This lets PyLint check to see if it can resolve imports
-      - name: Install dependencies
-        run: |
-          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
-          python launch.py
-      - name: Analysing the code with pylint
-        run: |
-          pylint $(git ls-files '*.py')
+          python-version: 3.11
+          # NB: there's no cache: pip here since we're not installing anything
+          #     from the requirements.txt file(s) in the repository; it's faster
+          #     not to have GHA download an (at the time of writing) 4 GB cache
+          #     of PyTorch and other dependencies.
+      - name: Install Ruff
+        run: pip install ruff==0.0.265
+      - name: Run Ruff
+        run: ruff .
+
+# The rest are currently disabled pending fixing of e.g. installing the torch dependency.
+
+#      - name: Install PyLint
+#        run: |
+#          python -m pip install --upgrade pip
+#          pip install pylint
+#      # This lets PyLint check to see if it can resolve imports
+#      - name: Install dependencies
+#        run: |
+#          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
+#          python launch.py
+#      - name: Analysing the code with pylint
+#        run: |
+#          pylint $(git ls-files '*.py')

+ 6 - 0
.github/workflows/run_tests.yaml

@@ -17,8 +17,14 @@ jobs:
           cache: pip
           cache-dependency-path: |
             **/requirements*txt
+            launch.py
       - name: Run tests
         run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+        env:
+          PIP_DISABLE_PIP_VERSION_CHECK: "1"
+          PIP_PROGRESS_BAR: "off"
+          TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
+          WEBUI_LAUNCH_LIVE_OUTPUT: "1"
       - name: Upload main app stdout-stderr
         uses: actions/upload-artifact@v3
         if: always()

+ 2 - 1
.gitignore

@@ -32,4 +32,5 @@ notification.mp3
 /extensions
 /test/stdout.txt
 /test/stderr.txt
-/cache.json
+/cache.json*
+/config_states/

+ 120 - 0
CHANGELOG.md

@@ -0,0 +1,120 @@
+## 1.2.1
+
+### Features:
+ * add an option to always refer to lora by filenames
+
+### Bug Fixes:
+ * never refer to lora by an alias if multiple loras have the same alias or the alias is called none
+ * fix upscalers disappearing after the user reloads UI
+ * allow bf16 in safe unpickler (resolves problems with loading some loras)
+ * allow web UI to be run fully offline
+ * fix localizations not working
+ * fix error for loras: 'LatentDiffusion' object has no attribute 'lora_layer_mapping'
+
+## 1.2.0
+
+### Features:
+ * do not wait for stable diffusion model to load at startup
+ * add filename patterns: [denoising]
+ * directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
+ * Lora: for the `<...>` text in prompt, use name of Lora that is in the metadata of the file, if present, instead of filename (both can be used to activate lora)
+ * Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
+ * Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
+ * Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
+ * add version to infotext, footer and console output when starting
+ * add links to wiki for filename pattern settings
+ * add extended info for quicksettings setting and use multiselect input instead of a text field
+
+### Minor:
+ * gradio bumped to 3.29.0
+ * torch bumped to 2.0.1
+ * --subpath option for gradio for use with reverse proxy
+ * linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
+ * possible frontend optimization: do not apply localizations if there are none
+ * Add extra `None` option for VAE in XYZ plot
+ * print error to console when batch processing in img2img fails
+ * create HTML for extra network pages only on demand
+ * allow directories starting with . to still list their models for lora, checkpoints, etc
+ * put infotext options into their own category in settings tab
+ * do not show licenses page when user selects Show all pages in settings
+
+### Extensions:
+ * Tooltip localization support
+ * Add api method to get LoRA models with prompt
+
+### Bug Fixes:
+ * re-add /docs endpoint
+ * fix gamepad navigation
+ * make the lightbox fullscreen image function properly
+ * fix squished thumbnails in extras tab
+ * keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
+ * fix webui showing the same image if you configure the generation to always save results into same file
+ * fix bug with upscalers not working properly
+ * Fix MPS on PyTorch 2.0.1, Intel Macs
+ * make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
+ * prevent Reload UI button/link from reloading the page when it's not yet ready
+ * fix prompts from file script failing to read contents from a drag/drop file
+
+
+## 1.1.1
+### Bug Fixes:
+ * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
+
+## 1.1.0
+### Features:
+ * switch to torch 2.0.0 (except for AMD GPUs)
+ * visual improvements to custom code scripts
+ * add filename patterns: [clip_skip], [hasprompt<>], [batch_number], [generation_number]
+ * add support for saving init images in img2img, and record their hashes in infotext for reproducibility
+ * automatically select current word when adjusting weight with ctrl+up/down
+ * add dropdowns for X/Y/Z plot
+ * setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
+ * support Gradio's theme API
+ * use TCMalloc on Linux by default; possible fix for memory leaks
+ * (optimization) option to remove negative conditioning at low sigma values #9177
+ * embed model merge metadata in .safetensors file
+ * extension settings backup/restore feature #9169
+ * add "resize by" and "resize to" tabs to img2img
+ * add option "keep original size" to textual inversion images preprocess
+ * image viewer scrolling via analog stick
+ * button to restore the progress from session lost / tab reload
+
+### Minor:
+ * gradio bumped to 3.28.1
+ * in extra tab, change extras "scale to" to sliders
+ * add labels to tool buttons to make it possible to hide them
+ * add tiled inference support for ScuNET
+ * add branch support for extension installation
+ * change linux installation script to install into current directory rather than /home/username
+ * sort textual inversion embeddings by name (case insensitive)
+ * allow styles.csv to be symlinked or mounted in docker
+ * remove the "do not add watermark to images" option
+ * make selected tab configurable with UI config
+ * extra networks UI is now fixed height and scrollable
+ * add disable_tls_verify arg for use with self-signed certs
+
+### Extensions:
+ * Add reload callback
+ * add is_hr_pass field for processing
+
+### Bug Fixes:
+ * fix broken batch image processing on 'Extras/Batch Process' tab
+ * add "None" option to extra networks dropdowns
+ * fix FileExistsError for CLIP Interrogator
+ * fix /sdapi/v1/txt2img endpoint not working on Linux #9319
+ * fix disappearing live previews and progressbar during slow tasks
+ * fix fullscreen image view not working properly in some cases
+ * prevent alwayson_scripts args param resizing script_arg list when they are inserted in it
+ * fix prompt schedule for second order samplers
+ * fix image mask/composite for weird resolutions #9628
+ * use correct images for previews when using AND (see #9491)
+ * one broken image in img2img batch won't stop all processing
+ * fix image orientation bug in train/preprocess
+ * fix Ngrok recreating tunnels every reload
+ * fix --realesrgan-models-path and --ldsr-models-path not working
+ * fix --skip-install not working
+ * outpainting Mk2 & Poorman should use the SAMPLE file format to save images, not GRID file format
+ * do not fail all Loras if some have failed to load when making a picture
+
+## 1.0.0
+  * everything

+ 11 - 3
README.md

@@ -99,8 +99,14 @@ Alternatively, use online services (like Google Colab):
 
 - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
 
+### Installation on Windows 10/11 with NVidia-GPUs using release package
+1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
+2. Run `update.bat`.
+3. Run `run.bat`.
+> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
+
 ### Automatic Installation on Windows
-1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
+1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (newer versions of Python do not support torch), checking "Add Python to PATH".
 2. Install [git](https://git-scm.com/download/win).
 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
 4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
@@ -115,11 +121,12 @@ sudo dnf install wget git python3
 # Arch-based:
 sudo pacman -S wget git python3
 ```
-2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
+2. Navigate to the directory you would like the webui to be installed in and execute the following command:
 ```bash
 bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
 ```
 3. Run `webui.sh`.
+4. Check `webui-user.sh` for options.
 ### Installation on Apple Silicon
 
 Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
@@ -157,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
+- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
-- (You)
+- (You)

+ 5 - 5
environment-wsl2.yaml

@@ -4,8 +4,8 @@ channels:
   - defaults
 dependencies:
   - python=3.10
-  - pip=22.2.2
-  - cudatoolkit=11.3
-  - pytorch=1.12.1
-  - torchvision=0.13.1
-  - numpy=1.23.1
+  - pip=23.0
+  - cudatoolkit=11.8
+  - pytorch=2.0
+  - torchvision=0.15
+  - numpy=1.23

+ 6 - 7
extensions-builtin/LDSR/ldsr_model_arch.py

@@ -88,7 +88,7 @@ class LDSR:
 
         x_t = None
         logs = None
-        for n in range(n_runs):
+        for _ in range(n_runs):
             if custom_shape is not None:
                 x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                 x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
         diffusion_steps = int(steps)
         eta = 1.0
 
-        down_sample_method = 'Lanczos'
 
         gc.collect()
         if torch.cuda.is_available:
@@ -131,11 +130,11 @@ class LDSR:
             im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
         else:
             print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
-        
+
         # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
         pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
         im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
-        
+
         logs = self.run(model["model"], im_padded, diffusion_steps, eta)
 
         sample = logs["sample"]
@@ -158,7 +157,7 @@ class LDSR:
 
 
 def get_cond(selected_path):
-    example = dict()
+    example = {}
     up_f = 4
     c = selected_path.convert('RGB')
     c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
 @torch.no_grad()
 def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                               corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
-    log = dict()
+    log = {}
 
     z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                         return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
         x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
         log["sample_noquant"] = x_sample_noquant
         log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
-    except:
+    except Exception:
         pass
 
     log["sample"] = x_sample

+ 15 - 8
extensions-builtin/LDSR/scripts/ldsr_model.py

@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+import sd_hijack_autoencoder  # noqa: F401
+import sd_hijack_ddpm_v1  # noqa: F401
 
 
 class UpscalerLDSR(Upscaler):
@@ -25,22 +26,28 @@ class UpscalerLDSR(Upscaler):
         yaml_path = os.path.join(self.model_path, "project.yaml")
         old_model_path = os.path.join(self.model_path, "model.pth")
         new_model_path = os.path.join(self.model_path, "model.ckpt")
-        safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
+
+        local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
+        local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
+        local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
+        local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
+
         if os.path.exists(yaml_path):
             statinfo = os.stat(yaml_path)
             if statinfo.st_size >= 10485760:
                 print("Removing invalid LDSR YAML file.")
                 os.remove(yaml_path)
+
         if os.path.exists(old_model_path):
             print("Renaming model from model.pth to model.ckpt")
             os.rename(old_model_path, new_model_path)
-        if os.path.exists(safetensors_model_path):
-            model = safetensors_model_path
+
+        if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
+            model = local_safetensors_path
         else:
-            model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
-                                       file_name="model.ckpt", progress=True)
-        yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
-                                  file_name="project.yaml", progress=True)
+            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
+
+        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
 
         try:
             return LDSR(model, yaml)

+ 17 - 11
extensions-builtin/LDSR/sd_hijack_autoencoder.py

@@ -1,16 +1,21 @@
 # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
 # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
 # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
-
+import numpy as np
 import torch
 import pytorch_lightning as pl
 import torch.nn.functional as F
 from contextlib import contextmanager
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.util import instantiate_from_config
 
 import ldm.models.autoencoder
+from packaging import version
 
 class VQModel(pl.LightningModule):
     def __init__(self,
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
                  n_embed,
                  embed_dim,
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  image_key="image",
                  colorize_nlabels=None,
                  monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
         self.scheduler_config = scheduler_config
         self.lr_g_factor = lr_g_factor
 
@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
                 if context is not None:
                     print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list()):
+    def init_from_ckpt(self, path, ignore_keys=None):
         sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         log_dict = self._validation_step(batch, batch_idx)
         with self.ema_scope():
-            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+            self._validation_step(batch, batch_idx, suffix="_ema")
         return log_dict
 
     def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
         return self.decoder.conv_out.weight
 
     def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.image_key)
         x = x.to(self.device)
         if only_inputs:
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
         if plot_ema:
             with self.ema_scope():
                 xrec_ema, _ = self(x)
-                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+                if x.shape[1] > 3:
+                    xrec_ema = self.to_rgb(xrec_ema)
                 log["reconstructions_ema"] = xrec_ema
         return log
 
@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
 
 class VQModelInterface(VQModel):
     def __init__(self, embed_dim, *args, **kwargs):
-        super().__init__(embed_dim=embed_dim, *args, **kwargs)
+        super().__init__(*args, embed_dim=embed_dim, **kwargs)
         self.embed_dim = embed_dim
 
     def encode(self, x):
@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
         dec = self.decoder(quant)
         return dec
 
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
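
A note on the `ignore_keys=[]` to `ignore_keys=None` changes above: Python evaluates default argument values once, at function definition time, so a mutable default like a list is shared across all calls. A minimal illustration of the pitfall and the pattern the diff switches to (function names here are only for illustration):

```python
def append_bad(item, acc=[]):
    # the same list object is reused on every call
    acc.append(item)
    return acc

def append_good(item, acc=None):
    # a fresh list is created when the caller passes nothing
    acc = acc if acc is not None else []
    acc.append(item)
    return acc

print(append_bad(1), append_bad(2))    # [1, 2] [1, 2]  (shared state)
print(append_good(1), append_good(2))  # [1] [2]        (independent)
```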

+ 30 - 36
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py

@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
                  beta_schedule="linear",
                  loss_type="l2",
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  load_only_unet=False,
                  monitor="val/loss",
                  use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
         if monitor is not None:
             self.monitor = monitor
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
 
         self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                                linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
                 if context is not None:
                     print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
             sd = sd["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
 
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.first_stage_key)
         N = min(x.shape[0], N)
         n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
         log["inputs"] = x
 
         # get diffusion row
-        diffusion_row = list()
+        diffusion_row = []
         x_start = x[:n_row]
 
         for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
             conditioning_key = None
         ckpt_path = kwargs.pop("ckpt_path", None)
         ignore_keys = kwargs.pop("ignore_keys", [])
-        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
         self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
-        self.bbox_tokenizer = None  
+        self.bbox_tokenizer = None
 
         self.restarted_from_ckpt = False
         if ckpt_path is not None:
@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
                 z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
 
                 # 2. apply model loop over last dim
-                if isinstance(self.first_stage_model, VQModelInterface):  
+                if isinstance(self.first_stage_model, VQModelInterface):
                     output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                  force_not_quantize=predict_cids or force_not_quantize)
                                    for i in range(z.shape[-1])]
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)
 
-    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
-        def rescale_bbox(bbox):
-            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
-            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
-            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
-            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
-            return x0, y0, w, h
-
-        return [rescale_bbox(b) for b in bboxes]
-
     def apply_model(self, x_noisy, t, cond, return_ids=False):
 
         if isinstance(cond, dict):
@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
 
         if hasattr(self, "split_input_params"):
             assert len(cond) == 1  # todo can only deal with one conditioning atm
-            assert not return_ids  
+            assert not return_ids
             ks = self.split_input_params["ks"]  # eg. (128, 128)
             stride = self.split_input_params["stride"]  # eg. (64, 64)
 
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
 
@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(x0_partial)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         return img, intermediates
 
     @torch.no_grad()
@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(img)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
 
         if return_intermediates:
             return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
         return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
 
         use_ddim = ddim_steps is not None
 
-        log = dict()
+        log = {}
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
                                            force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
 
         if plot_diffusion_rows:
             # get diffusion row
-            diffusion_row = list()
+            diffusion_row = []
             z_start = z[:n_row]
             for t in range(self.num_timesteps):
                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
 
             if inpaint:
                 # make a simple center square
-                b, h, w = z.shape[0], z.shape[2], z.shape[3]
+                h, w = z.shape[2], z.shape[3]
                 mask = torch.ones(N, h, w).to(self.device)
                 # zeros will be filled in
                 mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
         assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
-        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
 
     def log_images(self, batch, N=8, *args, **kwargs):
-        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+        logs = super().log_images(*args, batch=batch, N=N, **kwargs)
 
         key = 'train' if self.training else 'validation'
         dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
         logs['bbox_image'] = cond_img
         return logs
 
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1

+ 2 - 1
extensions-builtin/Lora/extra_networks_lora.py

@@ -1,6 +1,7 @@
 from modules import extra_networks, shared
 import lora
 
+
 class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def __init__(self):
         super().__init__('lora')
@@ -8,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora
 
-        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+        if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
 

+ 121 - 24
extensions-builtin/Lora/lora.py

@@ -1,10 +1,9 @@
-import glob
 import os
 import re
 import torch
 from typing import Union
 
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, sd_models, errors, scripts
 
 metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
 
@@ -93,6 +92,7 @@ class LoraOnDisk:
             self.metadata = m
 
         self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # those are cover images and they are too big to display in UI as text
+        self.alias = self.metadata.get('ss_output_name', self.name)
 
 
 class LoraModule:
@@ -132,6 +132,10 @@ def load_lora(name, filename):
 
     sd = sd_models.read_state_dict(filename)
 
+    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
+    if not hasattr(shared.sd_model, 'lora_layer_mapping'):
+        assign_lora_names_to_compvis_modules(shared.sd_model)
+
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
 
@@ -165,12 +169,14 @@ def load_lora(name, filename):
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.MultiheadAttention:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.Conv2d:
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
             print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
-            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
 
         with torch.no_grad():
             module.weight.copy_(weight)
@@ -182,7 +188,7 @@ def load_lora(name, filename):
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
 
     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -199,11 +205,11 @@ def load_loras(names, multipliers=None):
 
     loaded_loras.clear()
 
-    loras_on_disk = [available_loras.get(name, None) for name in names]
-    if any([x is None for x in loras_on_disk]):
+    loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+    if any(x is None for x in loras_on_disk):
         list_available_loras()
 
-        loras_on_disk = [available_loras.get(name, None) for name in names]
+        loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
 
     for i, name in enumerate(names):
         lora = already_loaded.get(name, None)
@@ -211,7 +217,11 @@ def load_loras(names, multipliers=None):
         lora_on_disk = loras_on_disk[i]
         if lora_on_disk is not None:
             if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
-                lora = load_lora(name, lora_on_disk.filename)
+                try:
+                    lora = load_lora(name, lora_on_disk.filename)
+                except Exception as e:
+                    errors.display(e, f"loading Lora {lora_on_disk.filename}")
+                    continue
 
         if lora is None:
             print(f"Couldn't find Lora with name {name}")
@@ -228,6 +238,8 @@ def lora_calc_updown(lora, module, target):
 
         if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
             updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
         else:
             updown = up @ down
 
@@ -236,6 +248,19 @@ def lora_calc_updown(lora, module, target):
         return updown
 
 
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+    weights_backup = getattr(self, "lora_weights_backup", None)
+
+    if weights_backup is None:
+        return
+
+    if isinstance(self, torch.nn.MultiheadAttention):
+        self.in_proj_weight.copy_(weights_backup[0])
+        self.out_proj.weight.copy_(weights_backup[1])
+    else:
+        self.weight.copy_(weights_backup)
+
+
 def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of Loras to the weights of torch layer self.
@@ -260,12 +285,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
         self.lora_weights_backup = weights_backup
 
     if current_names != wanted_names:
-        if weights_backup is not None:
-            if isinstance(self, torch.nn.MultiheadAttention):
-                self.in_proj_weight.copy_(weights_backup[0])
-                self.out_proj.weight.copy_(weights_backup[1])
-            else:
-                self.weight.copy_(weights_backup)
+        lora_restore_weights_from_backup(self)
 
         for lora in loaded_loras:
             module = lora.modules.get(lora_layer_name, None)
@@ -293,15 +313,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
 
             print(f'failed to calculate lora weights for layer {lora_layer_name}')
 
-        setattr(self, "lora_current_names", wanted_names)
+        self.lora_current_names = wanted_names
+
+
+def lora_forward(module, input, original_forward):
+    """
+    Old way of applying Lora by executing operations during layer's forward.
+    Stacking many loras this way results in big performance degradation.
+    """
+
+    if len(loaded_loras) == 0:
+        return original_forward(module, input)
+
+    input = devices.cond_cast_unet(input)
+
+    lora_restore_weights_from_backup(module)
+    lora_reset_cached_weight(module)
+
+    res = original_forward(module, input)
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is None:
+            continue
+
+        module.up.to(device=devices.device)
+        module.down.to(device=devices.device)
+
+        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+    return res
 
 
 def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
-    setattr(self, "lora_current_names", ())
-    setattr(self, "lora_weights_backup", None)
+    self.lora_current_names = ()
+    self.lora_weights_backup = None
 
 
 def lora_Linear_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Linear_forward_before_lora(self, input)
@@ -314,6 +367,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
 
 
 def lora_Conv2d_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Conv2d_forward_before_lora(self, input)
@@ -339,24 +395,65 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
 
 def list_available_loras():
     available_loras.clear()
+    available_lora_aliases.clear()
+    forbidden_lora_aliases.clear()
+    forbidden_lora_aliases.update({"none": 1})
 
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
-    candidates = \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
-
+    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
     for filename in sorted(candidates, key=str.lower):
         if os.path.isdir(filename):
             continue
 
         name = os.path.splitext(os.path.basename(filename))[0]
+        entry = LoraOnDisk(name, filename)
+
+        available_loras[name] = entry
+
+        if entry.alias in available_lora_aliases:
+            forbidden_lora_aliases[entry.alias.lower()] = 1
+
+        available_lora_aliases[name] = entry
+        available_lora_aliases[entry.alias] = entry
+
+
+re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+        return  # if the other extension is active, it will handle those fields, no need to do anything
+
+    added = []
+
+    for k in params:
+        if not k.startswith("AddNet Model "):
+            continue
+
+        num = k[13:]
+
+        if params.get("AddNet Module " + num) != "LoRA":
+            continue
+
+        name = params.get("AddNet Model " + num)
+        if name is None:
+            continue
+
+        m = re_lora_name.match(name)
+        if m:
+            name = m.group(1)
+
+        multiplier = params.get("AddNet Weight A " + num, "1.0")
 
-        available_loras[name] = LoraOnDisk(name, filename)
+        added.append(f"<lora:{name}:{multiplier}>")
 
+    if added:
+        params["Prompt"] += "\n" + "".join(added)
 
 available_loras = {}
+available_lora_aliases = {}
+forbidden_lora_aliases = {}
 loaded_loras = []
 
 list_available_loras()
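
The new 3x3 branch in `lora_calc_updown` merges the low-rank pair of convolutions into a single weight delta by convolving the two weight tensors with each other. A minimal sketch of why that works for the common case of a 3x3 `lora_down` and a 1x1 `lora_up` (shapes below are illustrative, not taken from any real checkpoint):

```python
import torch
import torch.nn.functional as F

in_ch, out_ch, rank = 4, 8, 2
down = torch.randn(rank, in_ch, 3, 3)   # lora_down.weight (3x3 conv)
up = torch.randn(out_ch, rank, 1, 1)    # lora_up.weight (1x1 conv)

# same expression the diff uses for the 3x3 case:
# treat the down weight as a batch of in_ch "images" and convolve it with up
updown = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

# sanity check: one conv with the merged kernel equals down followed by up
x = torch.randn(1, in_ch, 16, 16)
composed = F.conv2d(F.conv2d(x, down, padding=1), up)
merged = F.conv2d(x, updown, padding=1)
print(torch.allclose(composed, merged, atol=1e-5))  # True
```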

+ 28 - 2
extensions-builtin/Lora/scripts/lora_script.py

@@ -1,12 +1,12 @@
 import torch
 import gradio as gr
+from fastapi import FastAPI
 
 import lora
 import extra_networks_lora
 import ui_extra_networks_lora
 from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
-
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
     torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
@@ -49,8 +49,34 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
+script_callbacks.on_infotext_pasted(lora.infotext_pasted)
 
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+    "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+}))
+
+
+shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
+    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
 }))
+
+
+def create_lora_json(obj: lora.LoraOnDisk):
+    return {
+        "name": obj.name,
+        "alias": obj.alias,
+        "path": obj.filename,
+        "metadata": obj.metadata,
+    }
+
+
+def api_loras(_: gr.Blocks, app: FastAPI):
+    @app.get("/sdapi/v1/loras")
+    async def get_loras():
+        return [create_lora_json(obj) for obj in lora.available_loras.values()]
+
+
+script_callbacks.on_app_started(api_loras)
+
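
The `api_loras` callback above registers a new read-only endpoint. A minimal sketch of querying it, assuming the webui was launched with `--api` and listens on the default `127.0.0.1:7860`:

```python
import requests

# list every Lora the webui knows about, with alias, path and metadata
resp = requests.get("http://127.0.0.1:7860/sdapi/v1/loras", timeout=10)
resp.raise_for_status()
for entry in resp.json():
    print(entry["name"], entry.get("alias"), entry["path"])
```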

+ 7 - 1
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -15,13 +15,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
     def list_items(self):
         for name, lora_on_disk in lora.available_loras.items():
             path, ext = os.path.splitext(lora_on_disk.filename)
+
+            if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
+                alias = name
+            else:
+                alias = lora_on_disk.alias
+
             yield {
                 "name": name,
                 "filename": path,
                 "preview": self.find_preview(path),
                 "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(lora_on_disk.filename),
-                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+                "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                 "local_preview": f"{path}.{shared.opts.samples_format}",
                 "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
             }

+ 79 - 16
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -5,11 +5,14 @@ import traceback
 import PIL.Image
 import numpy as np
 import torch
+from tqdm import tqdm
+
 from basicsr.utils.download_util import load_file_from_url
 
 import modules.upscaler
-from modules import devices, modelloader
+from modules import devices, modelloader, script_callbacks
 from scunet_model_arch import SCUNet as net
+from modules.shared import opts
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,28 +45,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
             scalers.append(scaler_data2)
         self.scalers = scalers
 
-    def do_upscale(self, img: PIL.Image, selected_file):
+    @staticmethod
+    @torch.no_grad()
+    def tiled_inference(img, model):
+        # test the image tile by tile
+        h, w = img.shape[2:]
+        tile = opts.SCUNET_tile
+        tile_overlap = opts.SCUNET_tile_overlap
+        if tile == 0:
+            return model(img)
+
+        device = devices.get_device_for('scunet')
+        assert tile % 8 == 0, "tile size should be a multiple of window_size"
+        sf = 1
+
+        stride = tile - tile_overlap
+        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
+        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
+
+        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
+            for h_idx in h_idx_list:
+
+                for w_idx in w_idx_list:
+
+                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+
+                    out_patch = model(in_patch)
+                    out_patch_mask = torch.ones_like(out_patch)
+
+                    E[
+                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                    ].add_(out_patch)
+                    W[
+                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+                    ].add_(out_patch_mask)
+                    pbar.update(1)
+        output = E.div_(W)
+
+        return output
+
+    def do_upscale(self, img: PIL.Image.Image, selected_file):
+
         torch.cuda.empty_cache()
 
         model = self.load_model(selected_file)
         if model is None:
+            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
             return img
 
         device = devices.get_device_for('scunet')
-        img = np.array(img)
-        img = img[:, :, ::-1]
-        img = np.moveaxis(img, 2, 0) / 255
-        img = torch.from_numpy(img).float()
-        img = img.unsqueeze(0).to(device)
-
-        with torch.no_grad():
-            output = model(img)
-        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-        output = 255. * np.moveaxis(output, 0, 2)
-        output = output.astype(np.uint8)
-        output = output[:, :, ::-1]
+        tile = opts.SCUNET_tile
+        h, w = img.height, img.width
+        np_img = np.array(img)
+        np_img = np_img[:, :, ::-1]  # RGB to BGR
+        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
+        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
+
+        if tile > h or tile > w:
+            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
+            _img[:, :, :h, :w] = torch_img # pad image
+            torch_img = _img
+
+        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+        torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
+        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
+        del torch_img, torch_output
         torch.cuda.empty_cache()
-        return PIL.Image.fromarray(output, 'RGB')
+
+        output = np_output.transpose((1, 2, 0))  # CHW to HWC
+        output = output[:, :, ::-1]  # BGR to RGB
+        return PIL.Image.fromarray((output * 255).astype(np.uint8))
 
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
@@ -79,9 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
         model.load_state_dict(torch.load(filename), strict=True)
         model.eval()
-        for k, v in model.named_parameters():
+        for _, v in model.named_parameters():
             v.requires_grad = False
         model = model.to(device)
 
         return model
 
+
+def on_ui_settings():
+    import gradio as gr
+    from modules import shared
+
+    shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
+    shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
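
`tiled_inference` above accumulates each processed tile into `E` and a matching mask of ones into `W`, then divides, so pixels covered by several overlapping tiles are averaged rather than overwritten. A minimal 1D sketch of that accumulate-and-divide scheme (toy sizes and a stand-in "model", not the actual network):

```python
import torch

def model(patch):
    # stand-in for the real network
    return patch * 2

length, tile, overlap = 10, 4, 2
signal = torch.arange(length, dtype=torch.float32)

stride = tile - overlap
starts = list(range(0, length - tile, stride)) + [length - tile]

E = torch.zeros(length)   # accumulated outputs
W = torch.zeros(length)   # how many tiles covered each position
for s in starts:
    E[s:s + tile] += model(signal[s:s + tile])
    W[s:s + tile] += 1

result = E / W            # overlapping regions are averaged
print(torch.allclose(result, signal * 2))  # True for this linear stand-in
```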

+ 7 - 4
extensions-builtin/ScuNET/scunet_model_arch.py

@@ -61,7 +61,9 @@ class WMSA(nn.Module):
         Returns:
             output: tensor shape [b h w c]
         """
-        if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+        if self.type != 'W':
+            x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
         x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
         h_windows = x.size(1)
         w_windows = x.size(2)
@@ -85,8 +87,9 @@ class WMSA(nn.Module):
         output = self.linear(output)
         output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
 
-        if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
-                                                 dims=(1, 2))
+        if self.type != 'W':
+            output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
+
         return output
 
     def relative_embedding(self):
@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
+            nn.init.constant_(m.weight, 1.0)

+ 3 - 4
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -1,4 +1,3 @@
-import contextlib
 import os
 
 import numpy as np
@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
+from modules.shared import opts, state
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
         img = upscale(img, model)
         try:
             torch.cuda.empty_cache()
-        except:
+        except Exception:
             pass
         return img
 
@@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
             for w_idx in w_idx_list:
                 if state.interrupted or state.skipped:
                     break
-                
+
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                 out_patch = model(in_patch)
                 out_patch_mask = torch.ones_like(out_patch)

+ 3 - 3
extensions-builtin/SwinIR/swinir_model_arch.py

@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
     """
 
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
     def forward(self, x):
         H, W = x.shape[2:]
         x = self.check_image_size(x)
-        
+
         self.mean = self.mean.type_as(x)
         x = (x - self.mean) * self.img_range
 
@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
         H, W = self.patches_resolution
         flops += H * W * 3 * self.embed_dim * 9
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += H * W * 3 * self.embed_dim * self.embed_dim
         flops += self.upsample.flops()

+ 29 - 29
extensions-builtin/SwinIR/swinir_model_arch_v2.py

@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
     """
 
     def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=[0, 0]):
+                 pretrained_window_size=(0, 0)):
 
         super().__init__()
         self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
             attn_mask = None
 
         self.register_buffer("attn_mask", attn_mask)
-        
+
     def calculate_mask(self, x_size):
         # calculate attention mask for SW-MSA
         H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
         attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
         attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
 
-        return attn_mask        
+        return attn_mask
 
     def forward(self, x, x_size):
         H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
             attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
         else:
             attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-            
+
         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
         shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
         H, W = self.input_resolution
         flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
         flops += H * W * self.dim // 2
-        return flops    
+        return flops
 
 class BasicLayer(nn.Module):
     """ A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
             nn.init.constant_(blk.norm1.weight, 0)
             nn.init.constant_(blk.norm2.bias, 0)
             nn.init.constant_(blk.norm2.weight, 0)
-            
+
 class PatchEmbed(nn.Module):
     r""" Image to Patch Embedding
     Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
         if self.norm is not None:
             flops += Ho * Wo * self.embed_dim
-        return flops           
+        return flops
 
 class RSTB(nn.Module):
     """Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module):
                                          num_heads=num_heads,
                                          window_size=window_size,
                                          mlp_ratio=mlp_ratio,
-                                         qkv_bias=qkv_bias, 
+                                         qkv_bias=qkv_bias,
                                          drop=drop, attn_drop=attn_drop,
                                          drop_path=drop_path,
                                          norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
         else:
             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
         super(Upsample, self).__init__(*m)
-        
+
 class Upsample_hf(nn.Sequential):
     """Upsample module.
 
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
             m.append(nn.PixelShuffle(3))
         else:
             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
-        super(Upsample_hf, self).__init__(*m)        
+        super(Upsample_hf, self).__init__(*m)
 
 
 class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
         H, W = self.input_resolution
         flops = H * W * self.num_feat * 3 * 9
         return flops
-    
-    
+
+
 
 class Swin2SR(nn.Module):
     r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
     """
 
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
-                 window_size=7, mlp_ratio=4., qkv_bias=True, 
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+                 window_size=7, mlp_ratio=4., qkv_bias=True,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
                          num_heads=num_heads[i_layer],
                          window_size=window_size,
                          mlp_ratio=self.mlp_ratio,
-                         qkv_bias=qkv_bias, 
+                         qkv_bias=qkv_bias,
                          drop=drop_rate, attn_drop=attn_drop_rate,
                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                          norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
 
                          )
             self.layers.append(layer)
-            
+
         if self.upsampler == 'pixelshuffle_hf':
             self.layers_hf = nn.ModuleList()
             for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
                              num_heads=num_heads[i_layer],
                              window_size=window_size,
                              mlp_ratio=self.mlp_ratio,
-                             qkv_bias=qkv_bias, 
+                             qkv_bias=qkv_bias,
                              drop=drop_rate, attn_drop=attn_drop_rate,
                              drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                              norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
 
                              )
                 self.layers_hf.append(layer)
-                        
+
         self.norm = norm_layer(self.num_features)
 
         # build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
             self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
             self.conv_after_aux = nn.Sequential(
                 nn.Conv2d(3, num_feat, 3, 1, 1),
-                nn.LeakyReLU(inplace=True))            
+                nn.LeakyReLU(inplace=True))
             self.upsample = Upsample(upscale, num_feat)
             self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-            
+
         elif self.upsampler == 'pixelshuffle_hf':
             self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                       nn.LeakyReLU(inplace=True))
@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
                 nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                 nn.LeakyReLU(inplace=True))
             self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-            
+
         elif self.upsampler == 'pixelshuffledirect':
             # for lightweight SR (to save parameters)
             self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
         x = self.patch_unembed(x, x_size)
 
         return x
-    
+
     def forward_features_hf(self, x):
         x_size = (x.shape[2], x.shape[3])
         x = self.patch_embed(x)
@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
         x = self.norm(x)  # B L C
         x = self.patch_unembed(x, x_size)
 
-        return x    
+        return x
 
     def forward(self, x):
         H, W = x.shape[2:]
@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
             x = self.conv_after_body(self.forward_features(x)) + x
             x_before = self.conv_before_upsample(x)
             x_out = self.conv_last(self.upsample(x_before))
-            
+
             x_hf = self.conv_first_hf(x_before)
             x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
             x_hf = self.conv_before_upsample_hf(x_hf)
@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
             x_first = self.conv_first(x)
             res = self.conv_after_body(self.forward_features(x_first)) + x_first
             x = x + self.conv_last(res)
-        
+
         x = x / self.img_range + self.mean
         if self.upsampler == "pixelshuffle_aux":
             return x[:, :, :H*self.upscale, :W*self.upscale], aux
-        
+
         elif self.upsampler == "pixelshuffle_hf":
             x_out = x_out / self.img_range + self.mean
             return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-        
+
         else:
             return x[:, :, :H*self.upscale, :W*self.upscale]
 
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
         H, W = self.patches_resolution
         flops += H * W * 3 * self.embed_dim * 9
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += H * W * 3 * self.embed_dim * self.embed_dim
         flops += self.upsample.flops()
@@ -1014,4 +1014,4 @@ if __name__ == '__main__':
 
     x = torch.randn((1, 3, height, width))
     x = model(x)
-    print(x.shape)
+    print(x.shape)

+ 30 - 91
extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js

@@ -1,103 +1,42 @@
 // Stable Diffusion WebUI - Bracket checker
-// Version 1.0
-// By Hingashi no Florin/Bwin4L
+// By Hingashi no Florin/Bwin4L & @akx
 // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
 // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
 
-function checkBrackets(evt, textArea, counterElt) {
-  errorStringParen  = '(...) - Different number of opening and closing parentheses detected.\n';
-  errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
-  errorStringCurly  = '{...} - Different number of opening and closing curly brackets detected.\n';
-
-  openBracketRegExp = /\(/g;
-  closeBracketRegExp = /\)/g;
-
-  openSquareBracketRegExp = /\[/g;
-  closeSquareBracketRegExp = /\]/g;
-
-  openCurlyBracketRegExp = /\{/g;
-  closeCurlyBracketRegExp = /\}/g;
-
-  totalOpenBracketMatches = 0;
-  totalCloseBracketMatches = 0;
-  totalOpenSquareBracketMatches = 0;
-  totalCloseSquareBracketMatches = 0;
-  totalOpenCurlyBracketMatches = 0;
-  totalCloseCurlyBracketMatches = 0;
-
-  openBracketMatches = textArea.value.match(openBracketRegExp);
-  if(openBracketMatches) {
-    totalOpenBracketMatches = openBracketMatches.length;
-  }
-
-  closeBracketMatches = textArea.value.match(closeBracketRegExp);
-  if(closeBracketMatches) {
-    totalCloseBracketMatches = closeBracketMatches.length;
-  }
-
-  openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
-  if(openSquareBracketMatches) {
-    totalOpenSquareBracketMatches = openSquareBracketMatches.length;
-  }
-
-  closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
-  if(closeSquareBracketMatches) {
-    totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
-  }
-
-  openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
-  if(openCurlyBracketMatches) {
-    totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
-  }
-
-  closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
-  if(closeCurlyBracketMatches) {
-    totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
-  }
-
-  if(totalOpenBracketMatches != totalCloseBracketMatches) {
-    if(!counterElt.title.includes(errorStringParen)) {
-      counterElt.title += errorStringParen;
-    }
-  } else {
-    counterElt.title = counterElt.title.replace(errorStringParen, '');
-  }
-
-  if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
-    if(!counterElt.title.includes(errorStringSquare)) {
-      counterElt.title += errorStringSquare;
-    }
-  } else {
-    counterElt.title = counterElt.title.replace(errorStringSquare, '');
-  }
-
-  if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
-    if(!counterElt.title.includes(errorStringCurly)) {
-      counterElt.title += errorStringCurly;
+function checkBrackets(textArea, counterElt) {
+  var counts = {};
+  (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
+    counts[bracket] = (counts[bracket] || 0) + 1;
+  });
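+  // counts now maps each bracket character to its number of occurrences, e.g. "((a)" -> {"(": 2, ")": 1}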
+  var errors = [];
+
+  function checkPair(open, close, kind) {
+    if (counts[open] !== counts[close]) {
+      errors.push(
+        `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
+      );
     }
-  } else {
-    counterElt.title = counterElt.title.replace(errorStringCurly, '');
   }
 
-  if(counterElt.title != '') {
-    counterElt.classList.add('error');
-  } else {
-    counterElt.classList.remove('error');
-  }
+  checkPair('(', ')', 'round brackets');
+  checkPair('[', ']', 'square brackets');
+  checkPair('{', '}', 'curly brackets');
+  counterElt.title = errors.join('\n');
+  counterElt.classList.toggle('error', errors.length !== 0);
 }
 
-function setupBracketChecking(id_prompt, id_counter){
-    var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
-    var counter = gradioApp().getElementById(id_counter)
+function setupBracketChecking(id_prompt, id_counter) {
+  var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+  var counter = gradioApp().getElementById(id_counter)
 
-    textarea.addEventListener("input", function(evt){
-        checkBrackets(evt, textarea, counter)
-    });
+  if (textarea && counter) {
+    textarea.addEventListener("input", () => checkBrackets(textarea, counter));
+  }
 }
 
-onUiLoaded(function(){
-    setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
-    setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
-    setupBracketChecking('img2img_prompt', 'img2img_token_counter')
-    setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
-})
+onUiLoaded(function () {
+  setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
+  setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
+  setupBracketChecking('img2img_prompt', 'img2img_token_counter');
+  setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
+});

+ 1 - 1
html/extra-networks-card.html

@@ -6,7 +6,7 @@
 			<ul>
 				<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
 			</ul>
-			<span style="display:none" class='search_term'>{search_term}</span>
+			<span style="display:none" class='search_term{search_only}'>{search_term}</span>
 		</div>
 		<span class='name'>{name}</span>
 		<span class='description'>{description}</span>

+ 26 - 0
html/licenses.html

@@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
+<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
 </pre>

+ 14 - 19
javascript/aspectRatioOverlay.js

@@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
 
 		var viewportOffset = targetElement.getBoundingClientRect();
 
-		viewportscale = Math.min(  targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
+		var viewportscale = Math.min(  targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
 
-		scaledx = targetElement.naturalWidth*viewportscale
-		scaledy = targetElement.naturalHeight*viewportscale
+		var scaledx = targetElement.naturalWidth*viewportscale
+		var scaledy = targetElement.naturalHeight*viewportscale
 
-		cleintRectTop    = (viewportOffset.top+window.scrollY)
-		cleintRectLeft   = (viewportOffset.left+window.scrollX)
-		cleintRectCentreY = cleintRectTop  + (targetElement.clientHeight/2)
-		cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
+		var cleintRectTop    = (viewportOffset.top+window.scrollY)
+		var cleintRectLeft   = (viewportOffset.left+window.scrollX)
+		var cleintRectCentreY = cleintRectTop  + (targetElement.clientHeight/2)
+		var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
 
-		viewRectTop    = cleintRectCentreY-(scaledy/2)
-		viewRectLeft   = cleintRectCentreX-(scaledx/2)
-		arRectWidth  = scaledx
-		arRectHeight = scaledy
+		var arscale = Math.min(  scaledx/currentWidth, scaledy/currentHeight )
+		var arscaledx = currentWidth*arscale
+		var arscaledy = currentHeight*arscale
 
-		arscale = Math.min(  arRectWidth/currentWidth, arRectHeight/currentHeight )
-		arscaledx = currentWidth*arscale
-		arscaledy = currentHeight*arscale
-
-		arRectTop    = cleintRectCentreY-(arscaledy/2)
-		arRectLeft   = cleintRectCentreX-(arscaledx/2)
-		arRectWidth  = arscaledx
-		arRectHeight = arscaledy
+		var arRectTop    = cleintRectCentreY-(arscaledy/2)
+		var arRectLeft   = cleintRectCentreX-(arscaledx/2)
+		var arRectWidth  = arscaledx
+		var arRectHeight = arscaledy
 
 	    arPreviewRect.style.top  = arRectTop+'px';
 	    arPreviewRect.style.left = arRectLeft+'px';

+ 6 - 17
javascript/contextMenus.js

@@ -4,7 +4,7 @@ contextMenuInit = function(){
   let menuSpecs = new Map();
 
   const uid = function(){
-    return Date.now().toString(36) + Math.random().toString(36).substr(2);
+    return Date.now().toString(36) + Math.random().toString(36).substring(2);
   }
 
   function showContextMenu(event,element,menuEntries){
@@ -16,8 +16,7 @@ contextMenuInit = function(){
       oldMenu.remove()
     }
 
-    let tabButton = uiCurrentTab
-    let baseStyle = window.getComputedStyle(tabButton)
+    let baseStyle = window.getComputedStyle(uiCurrentTab)
 
     const contextMenu = document.createElement('nav')
     contextMenu.id = "context-menu"
@@ -36,7 +35,7 @@ contextMenuInit = function(){
     menuEntries.forEach(function(entry){
       let contextMenuEntry = document.createElement('a')
       contextMenuEntry.innerHTML = entry['name']
-      contextMenuEntry.addEventListener("click", function(e) {
+      contextMenuEntry.addEventListener("click", function() {
         entry['func']();
       })
       contextMenuList.append(contextMenuEntry);
@@ -63,7 +62,7 @@ contextMenuInit = function(){
 
   function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
 
-    currentItems = menuSpecs.get(targetElementSelector)
+    var currentItems = menuSpecs.get(targetElementSelector)
 
     if(!currentItems){
       currentItems = []
@@ -79,7 +78,7 @@ contextMenuInit = function(){
   }
 
   function removeContextMenuOption(uid){
-    menuSpecs.forEach(function(v,k) {
+    menuSpecs.forEach(function(v) {
       let index = -1
       v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
       if(index>=0){
@@ -93,8 +92,7 @@ contextMenuInit = function(){
       return;
     }
     gradioApp().addEventListener("click", function(e) {
-      let source = e.composedPath()[0]
-      if(source.id && source.id.indexOf('check_progress')>-1){
+      if(! e.isTrusted){
         return
       }
 
@@ -112,7 +110,6 @@ contextMenuInit = function(){
         if(e.composedPath()[0].matches(k)){
           showContextMenu(e,e.composedPath()[0],v)
           e.preventDefault()
-          return
         }
       })
     });
@@ -161,14 +158,6 @@ addContextMenuEventListener = initResponse[2];
   appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
   appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
 
-  appendContextMenuOption('#roll','Roll three',
-    function(){
-      let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
-      setTimeout(function(){rollbutton.click()},100)
-      setTimeout(function(){rollbutton.click()},200)
-      setTimeout(function(){rollbutton.click()},300)
-    }
-  )
 })();
 //End example Context Menu Items
 

+ 35 - 11
javascript/edit-attention.js

@@ -17,7 +17,7 @@ function keyupEditAttention(event){
 		// Find opening parenthesis around current cursor
 		const before = text.substring(0, selectionStart);
 		let beforeParen = before.lastIndexOf(OPEN);
-		if (beforeParen == -1) return  false;
+		if (beforeParen == -1) return false;
 		let beforeParenClose = before.lastIndexOf(CLOSE);
 		while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
 			beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
@@ -27,7 +27,7 @@ function keyupEditAttention(event){
 		// Find closing parenthesis around current cursor
 		const after = text.substring(selectionStart);
 		let afterParen = after.indexOf(CLOSE);
-		if (afterParen == -1) return  false;
+		if (afterParen == -1) return false;
 		let afterParenOpen = after.indexOf(OPEN);
 		while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
 			afterParen = after.indexOf(CLOSE, afterParen + 1);
@@ -43,16 +43,34 @@ function keyupEditAttention(event){
 		target.setSelectionRange(selectionStart, selectionEnd);
 		return true;
     }
+    
+    function selectCurrentWord(){
+        if (selectionStart !== selectionEnd) return false;
+        const delimiters = opts.keyedit_delimiters + " \r\n\t";
+        
+        // seek backward to find the beginning
+        while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
+            selectionStart--;
+        }
+        
+        // seek forward to find end
+        while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
+            selectionEnd++;
+        }
 
-	// If the user hasn't selected anything, let's select their current parenthesis block
-    if(! selectCurrentParenthesisBlock('<', '>')){
-        selectCurrentParenthesisBlock('(', ')')
+        target.setSelectionRange(selectionStart, selectionEnd);
+        return true;
+    }
+
+	// If the user hasn't selected anything, let's select their current parenthesis block or word
+    if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+        selectCurrentWord();
     }
 
 	event.preventDefault();
 
-    closeCharacter = ')'
-    delta = opts.keyedit_precision_attention
+    var closeCharacter = ')'
+    var delta = opts.keyedit_precision_attention
 
     if (selectionStart > 0 && text[selectionStart - 1] == '<'){
         closeCharacter = '>'
@@ -73,15 +91,21 @@ function keyupEditAttention(event){
         selectionEnd += 1;
     }
 
-	end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
-	weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+	var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+	var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
 	if (isNaN(weight)) return;
 
 	weight += isPlus ? delta : -delta;
 	weight = parseFloat(weight.toPrecision(12));
 	if(String(weight).length == 1) weight += ".0"
 
-	text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+    if (closeCharacter == ')' && weight == 1) {
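+        // e.g. with a 0.1 step, "(cat:1.1)" decremented once lands back on 1 and collapses to plain "cat"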
+        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+        selectionStart--;
+        selectionEnd--;
+    } else {
+        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+    }
 
 	target.focus();
 	target.value = text;
@@ -93,4 +117,4 @@ function keyupEditAttention(event){
 
 addEventListener('keydown', (event) => {
     keyupEditAttention(event);
-});
+});

+ 28 - 6
javascript/extensions.js

@@ -1,14 +1,14 @@
 
-function extensions_apply(_, _, disable_all){
+function extensions_apply(_disabled_list, _update_list, disable_all){
     var disable = []
     var update = []
 
     gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
         if(x.name.startsWith("enable_") && ! x.checked)
-            disable.push(x.name.substr(7))
+            disable.push(x.name.substring(7))
 
         if(x.name.startsWith("update_") && x.checked)
-            update.push(x.name.substr(7))
+            update.push(x.name.substring(7))
     })
 
     restart_reload()
@@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
     return [JSON.stringify(disable), JSON.stringify(update), disable_all]
 }
 
-function extensions_check(_, _){
+function extensions_check(){
     var disable = []
 
     gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
         if(x.name.startsWith("enable_") && ! x.checked)
-            disable.push(x.name.substr(7))
+            disable.push(x.name.substring(7))
     })
 
     gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
@@ -41,9 +41,31 @@ function install_extension_from_index(button, url){
     button.disabled = "disabled"
     button.value = "Installing..."
 
-    textarea = gradioApp().querySelector('#extension_to_install textarea')
+    var textarea = gradioApp().querySelector('#extension_to_install textarea')
     textarea.value = url
     updateInput(textarea)
 
     gradioApp().querySelector('#install_extension_button').click()
 }
+
+function config_state_confirm_restore(_, config_state_name, config_restore_type) {
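+    // ask for confirmation before restoring a saved config state; [confirmed, name, type] goes back to the gradio handler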
+    if (config_state_name == "Current") {
+        return [false, config_state_name, config_restore_type];
+    }
+    let restored = "";
+    if (config_restore_type == "extensions") {
+        restored = "all saved extension versions";
+    } else if (config_restore_type == "webui") {
+        restored = "the webui version";
+    } else {
+        restored = "the webui version and all saved extension versions";
+    }
+    let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
+    if (confirmed) {
+        restart_reload();
+        gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
+            x.innerHTML = "Loading..."
+        })
+    }
+    return [confirmed, config_state_name, config_restore_type];
+}

+ 48 - 22
javascript/extraNetworks.js

@@ -1,4 +1,3 @@
-
 function setupExtraNetworksForTab(tabname){
     gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
 
@@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
     tabs.appendChild(search)
     tabs.appendChild(refresh)
 
-    search.addEventListener("input", function(evt){
-        searchTerm = search.value.toLowerCase()
+    var applyFilter = function(){
+        var searchTerm = search.value.toLowerCase()
 
         gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
-            text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
-            elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
+            var searchOnly = elem.querySelector('.search_only')
+            var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
+
+            var visible = text.indexOf(searchTerm) != -1
+
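+            // cards flagged with the search_only class only show up once the query is at least 4 characters long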
+            if(searchOnly && searchTerm.length < 4){
+                visible = false
+            }
+
+            elem.style.display = visible ? "" : "none"
         })
-    });
+    }
+
+    search.addEventListener("input", applyFilter);
+    applyFilter();
+
+    extraNetworksApplyFilter[tabname] = applyFilter;
 }
 
+function applyExtraNetworkFilter(tabname){
+    setTimeout(extraNetworksApplyFilter[tabname], 1);
+}
+
+var extraNetworksApplyFilter = {}
 var activePromptTextarea = {};
 
 function setupExtraNetworks(){
@@ -51,18 +68,27 @@ var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
 
 function tryToRemoveExtraNetworkFromPrompt(textarea, text){
     var m = text.match(re_extranet)
-    if(! m) return false
-
-    var partToSearch = m[1]
     var replaced = false
-    var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
-        m = found.match(re_extranet);
-        if(m[1] == partToSearch){
-            replaced = true;
-            return ""
-        }
-        return found;
-    })
+    var newTextareaText
+    if(m) {
+        var partToSearch = m[1]
+        newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
+            m = found.match(re_extranet);
+            if(m[1] == partToSearch){
+                replaced = true;
+                return ""
+            }
+            return found;
+        })
+    } else {
+        newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found){
+            if(found == text) {
+                replaced = true;
+                return ""
+            }
+            return found;
+        })
+    }
 
     if(replaced){
         textarea.value = newTextareaText
@@ -96,9 +122,9 @@ function saveCardPreview(event, tabname, filename){
 }
 
 function extraNetworksSearchButton(tabs_id, event){
-    searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
-    button = event.target
-    text = button.classList.contains("search-all") ? "" : button.textContent.trim()
+    var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
+    var button = event.target
+    var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
 
     searchTextarea.value = text
     updateInput(searchTextarea)
@@ -133,7 +159,7 @@ function popup(contents){
 }
 
 function extraNetworksShowMetadata(text){
-    elem = document.createElement('pre')
+    var elem = document.createElement('pre')
     elem.classList.add('popup-metadata');
     elem.textContent = text;
 
@@ -165,7 +191,7 @@ function requestGet(url, data, handler, errorHandler){
 }
 
 function extraNetworksRequestMetadata(event, extraPage, cardName){
-    showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
+    var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
 
     requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
         if(data && data.metadata){

+ 4 - 4
javascript/generationParams.js

@@ -16,14 +16,14 @@ onUiUpdate(function(){
 
 let modalObserver = new MutationObserver(function(mutations) {
 	mutations.forEach(function(mutationRecord) {
-		let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
-		if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
-			gradioApp().getElementById(selectedTab+"_generation_info_button").click()
+		let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
+		if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
+			gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
 	});
 });
 
 function attachGalleryListeners(tab_name) {
-	gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
+	var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
 	gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
 	gallery?.addEventListener('keydown', (e) => {
 		if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow

+ 46 - 26
javascript/hints.js

@@ -22,6 +22,7 @@ titles = {
     "\u{1f4cb}": "Apply selected styles to current prompt",
     "\u{1f4d2}": "Paste available values into the field",
     "\u{1f3b4}": "Show/hide extra networks",
+    "\u{1f300}": "Restore progress",
 
     "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
     "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@@ -65,8 +66,8 @@ titles = {
 
     "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
 
-    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
-    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+    "Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
+    "Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
     "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
 
     "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@@ -85,7 +86,6 @@ titles = {
     "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
 
     "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
-    "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
 
     "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
     "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
@@ -111,37 +111,57 @@ titles = {
     "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
     "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
     "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
-    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
+    "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
+    "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
 }
 
+function updateTooltipForSpan(span){
+    if (span.title) return;  // already has a title
 
-onUiUpdate(function(){
-	gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
-		tooltip = titles[span.textContent];
+    let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
 
-		if(!tooltip){
-		    tooltip = titles[span.value];
-		}
+    if(!tooltip){
+        tooltip = localization[titles[span.value]] || titles[span.value];
+    }
 
-		if(!tooltip){
-			for (const c of span.classList) {
-				if (c in titles) {
-					tooltip = titles[c];
-					break;
-				}
+    if(!tooltip){
+		for (const c of span.classList) {
+			if (c in titles) {
+				tooltip = localization[titles[c]] || titles[c];
+				break;
 			}
 		}
+	}
 
-		if(tooltip){
-			span.title = tooltip;
-		}
-	})
+	if(tooltip){
+		span.title = tooltip;
+	}
+}
 
-	gradioApp().querySelectorAll('select').forEach(function(select){
-	    if (select.onchange != null) return;
+function updateTooltipForSelect(select){
+    if (select.onchange != null) return;
+
+    select.onchange = function(){
+        select.title = localization[titles[select.value]] || titles[select.value] || "";
+    }
+}
 
-	    select.onchange = function(){
-            select.title = titles[select.value] || "";
-	    }
-	})
+var observedTooltipElements = {"SPAN": 1, "BUTTON": 1, "SELECT": 1, "P": 1}
+
+onUiUpdate(function(m){
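+    // only inspect the nodes added by this batch of mutations instead of rescanning every element in the app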
+    m.forEach(function(record){
+        record.addedNodes.forEach(function(node){
+            if(observedTooltipElements[node.tagName]){
+                updateTooltipForSpan(node)
+            }
+            if(node.tagName == "SELECT"){
+                updateTooltipForSelect(node)
+            }
+
+            if(node.querySelectorAll){
+                node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan)
+                node.querySelectorAll('select').forEach(updateTooltipForSelect)
+            }
+        })
+    })
 })

+ 6 - 10
javascript/hires_fix.js

@@ -1,16 +1,12 @@
 
-function setInactive(elem, inactive){
-    if(inactive){
-        elem.classList.add('inactive')
-    } else{
-        elem.classList.remove('inactive')
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
+    function setInactive(elem, inactive){
+        elem.classList.toggle('inactive', !!inactive)
     }
-}
 
-function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
-    hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
-    hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
-    hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
+    var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
+    var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
+    var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
 
     gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
 

+ 6 - 7
javascript/imageMaskFix.js

@@ -2,11 +2,10 @@
  * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
  * @see https://github.com/gradio-app/gradio/issues/1721
  */
-window.addEventListener( 'resize', () => imageMaskResize());
 function imageMaskResize() {
     const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
     if ( ! canvases.length ) {
-    canvases_fixed = false;
+    canvases_fixed = false; // TODO: this is unused..?
     window.removeEventListener( 'resize', imageMaskResize );
     return;
     }
@@ -15,7 +14,7 @@ function imageMaskResize() {
     const previewImage = wrapper.previousElementSibling;
 
     if ( ! previewImage.complete ) {
-        previewImage.addEventListener( 'load', () => imageMaskResize());
+        previewImage.addEventListener( 'load', imageMaskResize);
         return;
     }
 
@@ -24,7 +23,6 @@ function imageMaskResize() {
     const nw = previewImage.naturalWidth;
     const nh = previewImage.naturalHeight;
     const portrait = nh > nw;
-    const factor = portrait;
 
     const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
     const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
@@ -40,6 +38,7 @@ function imageMaskResize() {
         c.style.maxHeight = '100%';
         c.style.objectFit = 'contain';
     });
- }
-  
- onUiUpdate(() => imageMaskResize());
+}
+
+onUiUpdate(imageMaskResize);
+window.addEventListener( 'resize', imageMaskResize);

+ 0 - 1
javascript/imageParams.js

@@ -1,7 +1,6 @@
 window.onload = (function(){
     window.addEventListener('drop', e => {
         const target = e.composedPath()[0];
-        const idx = selected_gallery_index();
         if (target.placeholder.indexOf("Prompt") == -1) return;
 
         let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";

+ 9 - 10
javascript/imageviewer.js

@@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
         })
 
         if (result != -1) {
-            nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
+            var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
             nextButton.click()
             const modalImage = gradioApp().getElementById("modalImage");
             const modal = gradioApp().getElementById("lightboxModal");
@@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
 }
 
 function modalZoomSet(modalImage, enable) {
-    if (enable) {
-        modalImage.classList.add('modalImageFullscreen');
-    } else {
-        modalImage.classList.remove('modalImageFullscreen');
-    }
+    if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
 }
 
 function modalZoomToggle(event) {
-    modalImage = gradioApp().getElementById("modalImage");
+    var modalImage = gradioApp().getElementById("modalImage");
     modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
     event.stopPropagation()
 }
@@ -179,7 +175,7 @@ function galleryImageHandler(e) {
 }
 
 onUiUpdate(function() {
-    fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
+    var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
     if (fullImg_preview != null) {
         fullImg_preview.forEach(setupImageForLightbox);
     }
@@ -251,8 +247,11 @@ document.addEventListener("DOMContentLoaded", function() {
 
     modal.appendChild(modalNext)
 
-    gradioApp().appendChild(modal)
-
+    try {
+		gradioApp().appendChild(modal);
+	} catch (e) {
+		gradioApp().body.appendChild(modal);
+	}
 
     document.body.appendChild(modal);
 

+ 57 - 0
javascript/imageviewerGamepad.js

@@ -0,0 +1,57 @@
+window.addEventListener('gamepadconnected', (e) => {
+    const index = e.gamepad.index;
+    let isWaiting = false;
+    setInterval(async () => {
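+        // poll the gamepad's first axis every 10 ms; pushing past ±0.3 switches image, then wait for the stick to recenter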
+        if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
+        const gamepad = navigator.getGamepads()[index];
+        const xValue = gamepad.axes[0];
+        if (xValue <= -0.3) {
+            modalPrevImage(e);
+            isWaiting = true;
+        } else if (xValue >= 0.3) {
+            modalNextImage(e);
+            isWaiting = true;
+        }
+        if (isWaiting) {
+            await sleepUntil(() => {
+                const xValue = navigator.getGamepads()[index].axes[0]
+                if (xValue < 0.3 && xValue > -0.3) {
+                    return true;
+                }
+            }, opts.js_modal_lightbox_gamepad_repeat);
+            isWaiting = false;
+        }
+    }, 10);
+});
+
+/*
+Primarily for VR controller type pointer devices.
+The wheel event is used because there's currently no way to do this properly with WebXR.
+ */
+let isScrolling = false;
+window.addEventListener('wheel', (e) => {
+    if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
+    isScrolling = true;
+
+    if (e.deltaX <= -0.6) {
+        modalPrevImage(e);
+    } else if (e.deltaX >= 0.6) {
+        modalNextImage(e);
+    }
+
+    setTimeout(() => {
+        isScrolling = false;
+    }, opts.js_modal_lightbox_gamepad_repeat);
+});
+
+function sleepUntil(f, timeout) {
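+    // resolves once f() returns truthy or the timeout (ms) has passed, polling every 20 ms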
+    return new Promise((resolve) => {
+        const timeStart = new Date();
+        const wait = setInterval(function() {
+            if (f() || new Date() - timeStart > timeout) {
+                clearInterval(wait);
+                resolve();
+            }
+        }, 20);
+    });
+}

+ 44 - 32
javascript/localization.js

@@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
 original_lines = {}
 translated_lines = {}
 
+function hasLocalization() {
+    return window.localization && Object.keys(window.localization).length > 0;
+}
+
 function textNodesUnder(el){
     var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
     while(n=walk.nextNode()) a.push(n);
@@ -35,11 +39,11 @@ function canBeTranslated(node, text){
     if(! text) return false;
     if(! node.parentElement) return false;
 
-    parentType = node.parentElement.nodeName
+    var parentType = node.parentElement.nodeName
     if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
 
     if (parentType=='OPTION' || parentType=='SPAN'){
-        pnode = node
+        var pnode = node
         for(var level=0; level<4; level++){
             pnode = pnode.parentElement
             if(! pnode) break;
@@ -69,7 +73,7 @@ function getTranslation(text){
 }
 
 function processTextNode(node){
-    text = node.textContent.trim()
+    var text = node.textContent.trim()
 
     if(! canBeTranslated(node, text)) return
 
@@ -105,30 +109,52 @@ function processNode(node){
 }
 
 function dumpTranslations(){
-    dumped = {}
+    if(!hasLocalization()) {
+        // If we don't have any localization,
+        // we will not have traversed the app to find
+        // original_lines, so do that now.
+        processNode(gradioApp());
+    }
+    var dumped = {}
     if (localization.rtl) {
-        dumped.rtl = true
+        dumped.rtl = true;
     }
 
-    Object.keys(original_lines).forEach(function(text){
-        if(dumped[text] !== undefined)  return
+    for (const text in original_lines) {
+        if(dumped[text] !== undefined) continue;
+        dumped[text] = localization[text] || text;
+    }
 
-        dumped[text] = localization[text] || text
-    })
+    return dumped;
+}
+
+function download_localization() {
+    var text = JSON.stringify(dumpTranslations(), null, 4)
 
-    return dumped
+    var element = document.createElement('a');
+    element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
+    element.setAttribute('download', "localization.json");
+    element.style.display = 'none';
+    document.body.appendChild(element);
+
+    element.click();
+
+    document.body.removeChild(element);
 }
 
-onUiUpdate(function(m){
-    m.forEach(function(mutation){
-        mutation.addedNodes.forEach(function(node){
-            processNode(node)
-        })
-    });
-})
+document.addEventListener("DOMContentLoaded", function () {
+    if (!hasLocalization()) {
+        return;
+    }
 
+    onUiUpdate(function (m) {
+        m.forEach(function (mutation) {
+            mutation.addedNodes.forEach(function (node) {
+                processNode(node)
+            })
+        });
+    })
 
-document.addEventListener("DOMContentLoaded", function() {
     processNode(gradioApp())
 
     if (localization.rtl) {  // if the language is from right to left,
@@ -149,17 +175,3 @@ document.addEventListener("DOMContentLoaded", function() {
         })).observe(gradioApp(), { childList: true });
     }
 })
-
-function download_localization() {
-    text = JSON.stringify(dumpTranslations(), null, 4)
-
-    var element = document.createElement('a');
-    element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
-    element.setAttribute('download', "localization.json");
-    element.style.display = 'none';
-    document.body.appendChild(element);
-
-    element.click();
-
-    document.body.removeChild(element);
-}

+ 3 - 3
javascript/notification.js

@@ -2,15 +2,15 @@
 
 let lastHeadImg = null;
 
-notificationButton = null
+let notificationButton = null;
 
 onUiUpdate(function(){
     if(notificationButton == null){
         notificationButton = gradioApp().getElementById('request_notifications')
 
         if(notificationButton != null){
-            notificationButton.addEventListener('click', function (evt) {
-                Notification.requestPermission();
+            notificationButton.addEventListener('click', () => {
+                void Notification.requestPermission();
             },true);
         }
     }

+ 5 - 6
javascript/progressbar.js

@@ -1,16 +1,15 @@
 // code related to showing and updating progressbar shown as the image is being made
 
-function rememberGallerySelection(id_gallery){
+function rememberGallerySelection(){
 
 }
 
-function getGallerySelectedIndex(id_gallery){
+function getGallerySelectedIndex(){
 
 }
 
 function request(url, data, handler, errorHandler){
     var xhr = new XMLHttpRequest();
-    var url = url;
     xhr.open("POST", url, true);
     xhr.setRequestHeader("Content-Type", "application/json");
     xhr.onreadystatechange = function () {
@@ -66,7 +65,7 @@ function randomId(){
 // starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
 // preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
 // calls onProgress every time there is a progress update
-function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
     var dateStart = new Date()
     var wasEverActive = false
     var parentProgressbar = progressbarContainer.parentNode
@@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
                 divProgress.style.width = rect.width + "px";
             }
 
-            progressText = ""
+            let progressText = ""
 
             divInner.style.width = ((res.progress || 0) * 100.0) + '%'
             divInner.style.background = res.progress ? "" : "transparent"
@@ -138,7 +137,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
                 return
             }
 
-            if(elapsedFromStart > 5 && !res.queued && !res.active){
+            if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
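+                // the task never queued or became active in time; stop polling and remove the progress bar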
                 removeProgressBar()
                 return
             }

+ 94 - 14
javascript/ui.js

@@ -1,7 +1,7 @@
 // various functions for interaction with ui.py not large enough to warrant putting them in separate files
 
 function set_theme(theme){
-    gradioURL = window.location.href
+    var gradioURL = window.location.href
     if (!gradioURL.includes('?__theme=')) {
       window.location.replace(gradioURL + '?__theme=' + theme);
     }
@@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
         return [gallery[0]];
     }
 
-    index = selected_gallery_index()
+    var index = selected_gallery_index()
 
     if (index < 0 || index >= gallery.length){
         // Use the first image in the gallery as the default
@@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
 }
 
 function args_to_array(args){
-    res = []
+    var res = []
     for(var i=0;i<args.length;i++){
         res.push(args[i])
     }
@@ -138,7 +138,7 @@ function get_img2img_tab_index() {
 }
 
 function create_submit_args(args){
-    res = []
+    var res = []
     for(var i=0;i<args.length;i++){
         res.push(args[i])
     }
@@ -159,14 +159,24 @@ function showSubmitButtons(tabname, show){
     gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
 }
 
+function showRestoreProgressButton(tabname, show){
+    var button = gradioApp().getElementById(tabname + "_restore_progress")
+    if(! button) return
+
+    button.style.display = show ? "flex" : "none"
+}
+
 function submit(){
     rememberGallerySelection('txt2img_gallery')
     showSubmitButtons('txt2img', false)
 
     var id = randomId()
+    localStorage.setItem("txt2img_task_id", id);
+
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
         showSubmitButtons('txt2img', true)
-
+        localStorage.removeItem("txt2img_task_id")
+        showRestoreProgressButton('txt2img', false)
     })
 
     var res = create_submit_args(arguments)
@@ -181,8 +191,12 @@ function submit_img2img(){
     showSubmitButtons('img2img', false)
 
     var id = randomId()
+    localStorage.setItem("img2img_task_id", id);
+
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
         showSubmitButtons('img2img', true)
+        localStorage.removeItem("img2img_task_id")
+        showRestoreProgressButton('img2img', false)
     })
 
     var res = create_submit_args(arguments)
@@ -193,6 +207,42 @@ function submit_img2img(){
     return res
 }
 
+function restoreProgressTxt2img(){
+    showRestoreProgressButton("txt2img", false)
+    var id = localStorage.getItem("txt2img_task_id")
+
+    if(id) {
+        requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
+            showSubmitButtons('txt2img', true)
+        }, null, 0)
+    }
+
+    return id
+}
+
+function restoreProgressImg2img(){
+    showRestoreProgressButton("img2img", false)
+    
+    var id = localStorage.getItem("img2img_task_id")
+
+    if(id) {
+        requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
+            showSubmitButtons('img2img', true)
+        }, null, 0)
+    }
+
+    return id
+}
+
+
+onUiLoaded(function () {
+    showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"))
+    showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"))
+});
+
+
 function modelmerger(){
     var id = randomId()
     requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
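
The hunk above makes txt2img/img2img jobs resumable across a page reload: submit() stores the task id in localStorage, the atEnd callback removes it and hides the restore button, onUiLoaded shows the button whenever a stored id survives, and restoreProgressTxt2img/Img2img resume requestProgress with inactivityTimeout 0 so a stale id is discarded at once. The same resume-by-id pattern sketched in Python, with a plain dict standing in for localStorage and start_task/poll_task as hypothetical helpers:

    storage = {}  # stands in for the browser's localStorage

    def submit(tab, start_task, poll_task):
        task_id = start_task()
        storage[f"{tab}_task_id"] = task_id      # remember the running job
        poll_task(task_id)                       # normal progress polling
        storage.pop(f"{tab}_task_id", None)      # forget it once the job ends

    def restore_progress(tab, poll_task):
        task_id = storage.get(f"{tab}_task_id")  # would survive a page reload
        if task_id:
            poll_task(task_id, inactivity_timeout=0)  # bail out fast on a stale id
        return task_id
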
@@ -204,7 +254,7 @@ function modelmerger(){
 
 
 function ask_for_style_name(_, prompt_text, negative_prompt_text) {
-    name_ = prompt('Style name:')
+    var name_ = prompt('Style name:')
     return [name_, prompt_text, negative_prompt_text]
 }
 
@@ -239,11 +289,11 @@ function recalculate_prompts_img2img(){
 }
 
 
-opts = {}
+var opts = {}
 onUiUpdate(function(){
 	if(Object.keys(opts).length != 0) return;
 
-	json_elem = gradioApp().getElementById('settings_json')
+	var json_elem = gradioApp().getElementById('settings_json')
 	if(json_elem == null) return;
 
     var textarea = json_elem.querySelector('textarea')
@@ -292,12 +342,15 @@ onUiUpdate(function(){
     registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
     registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
 
-    show_all_pages = gradioApp().getElementById('settings_show_all_pages')
-    settings_tabs = gradioApp().querySelector('#settings div')
+    var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
+    var settings_tabs = gradioApp().querySelector('#settings div')
     if(show_all_pages && settings_tabs){
         settings_tabs.appendChild(show_all_pages)
         show_all_pages.onclick = function(){
             gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
+                if(elem.id == "settings_tab_licenses")
+                    return;
+
                 elem.style.display = "block";
             })
         }
@@ -305,9 +358,9 @@ onUiUpdate(function(){
 })
 
 onOptionsChanged(function(){
-    elem = gradioApp().getElementById('sd_checkpoint_hash')
-    sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
-    shorthash = sd_checkpoint_hash.substr(0,10)
+    var elem = gradioApp().getElementById('sd_checkpoint_hash')
+    var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
+    var shorthash = sd_checkpoint_hash.substring(0,10)
 
 	if(elem && elem.textContent != shorthash){
 	    elem.textContent = shorthash
@@ -342,7 +395,16 @@ function update_token_counter(button_id) {
 
 function restart_reload(){
     document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
-    setTimeout(function(){location.reload()},2000)
+
+    var requestPing = function(){
+        requestGet("./internal/ping", {}, function(data){
+            location.reload();
+        }, function(){
+            setTimeout(requestPing, 500);
+        })
+    }
+
+    setTimeout(requestPing, 2000);
 
     return []
 }
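
restart_reload() no longer reloads the page blindly after two seconds; it waits two seconds, then polls "./internal/ping" every 500 ms and reloads only once the backend answers. An equivalent wait loop in Python, assuming the default local address and using requests (not part of the repository):

    import time
    import requests

    def wait_for_server(base_url="http://127.0.0.1:7860", interval=0.5, timeout=120):
        # Block until GET {base_url}/internal/ping succeeds, i.e. the server
        # has finished restarting; raise if it never comes back.
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                if requests.get(f"{base_url}/internal/ping", timeout=2).ok:
                    return
            except requests.RequestException:
                pass  # still restarting, keep waiting
            time.sleep(interval)
        raise TimeoutError("server did not come back up")
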
@@ -362,6 +424,23 @@ function selectCheckpoint(name){
     gradioApp().getElementById('change_checkpoint').click()
 }
 
+function currentImg2imgSourceResolution(_, _, scaleBy){
+    var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
+    return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
+}
+
+function updateImg2imgResizeToTextAfterChangingImage(){
+    // At the time this is called from gradio, the image has not yet been replaced.
+    // There may be a better solution, but this is simple and straightforward so I'm going with it.
+
+    setTimeout(function() {
+        gradioApp().getElementById('img2img_update_resize_to').click()
+    }, 500);
+
+    return []
+
+}
+
 function setRandomSeed(target_interface) {
     let seed = gradioApp().querySelector(`#${target_interface}_seed input`);
     if (!seed) {
@@ -409,3 +488,4 @@ function switchWidthHeightImg2Img() {
     height.dispatchEvent(new Event("input"));
     return [];
 }
+

+ 62 - 0
javascript/ui_settings_hints.js

@@ -0,0 +1,62 @@
+// various hints and extra info for the settings tab
+
+settingsHintsSetup = false
+
+onOptionsChanged(function(){
+    if(settingsHintsSetup) return
+    settingsHintsSetup = true
+
+    gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div){
+        var name = div.id.substr(8)
+        var commentBefore = opts._comments_before[name]
+        var commentAfter = opts._comments_after[name]
+
+        if(! commentBefore && !commentAfter) return
+
+        var span = null
+        if(div.classList.contains('gradio-checkbox')) span = div.querySelector('label span')
+        else if(div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild
+        else if(div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild
+        else span = div.querySelector('label span').firstChild
+
+        if(!span) return
+
+        if(commentBefore){
+            var comment = document.createElement('DIV')
+            comment.className = 'settings-comment'
+            comment.innerHTML = commentBefore
+            span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
+            span.parentElement.insertBefore(comment, span)
+            span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
+        }
+        if(commentAfter){
+            var comment = document.createElement('DIV')
+            comment.className = 'settings-comment'
+            comment.innerHTML = commentAfter
+            span.parentElement.insertBefore(comment, span.nextSibling)
+            span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling)
+        }
+    })
+})
+
+function settingsHintsShowQuicksettings(){
+	requestGet("./internal/quicksettings-hint", {}, function(data){
+		var table = document.createElement('table')
+		table.className = 'settings-value-table'
+
+		data.forEach(function(obj){
+			var tr = document.createElement('tr')
+			var td = document.createElement('td')
+			td.textContent = obj.name
+			tr.appendChild(td)
+
+			var td = document.createElement('td')
+			td.textContent = obj.label
+			tr.appendChild(td)
+
+			table.appendChild(tr)
+		})
+
+		popup(table);
+	})
+}

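The new settingsHintsShowQuicksettings() fetches "./internal/quicksettings-hint", which returns a list of objects pairing a setting's internal name with its UI label, and shows them as a two-column table in a popup. The same data can be inspected from a script; a small sketch using requests, with the base URL as an assumption:

    import requests

    def print_quicksettings_hint(base_url="http://127.0.0.1:7860"):
        # Each entry pairs an internal setting name with its user-facing label.
        for entry in requests.get(f"{base_url}/internal/quicksettings-hint", timeout=5).json():
            print(f"{entry['name']:40} {entry['label']}")
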
+ 58 - 58
launch.py

@@ -3,25 +3,23 @@ import subprocess
 import os
 import sys
 import importlib.util
-import shlex
 import platform
 import json
+from functools import lru_cache
 
 from modules import cmd_args
 from modules.paths_internal import script_path, extensions_dir
 
-commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
-sys.argv += shlex.split(commandline_args)
-
 args, _ = cmd_args.parser.parse_known_args()
 
 python = sys.executable
 git = os.environ.get('GIT', "git")
 index_url = os.environ.get('INDEX_URL', "")
-stored_commit_hash = None
-skip_install = False
 dir_repos = "repositories"
 
+# Whether to default to printing command output
+default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
+
 if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
     os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
@@ -49,7 +47,7 @@ or any other error regarding unsuccessful package (library) installation,
 please downgrade (or upgrade) to the latest version of 3.10 Python
 and delete current Python and "venv" folder in WebUI's directory.
 
-You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
 
 {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
 
@@ -57,51 +55,52 @@ Use --skip-python-version-check to suppress this warning.
 """)
 
 
+@lru_cache()
 def commit_hash():
-    global stored_commit_hash
+    try:
+        return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
+    except Exception:
+        return "<none>"
 
-    if stored_commit_hash is not None:
-        return stored_commit_hash
 
+@lru_cache()
+def git_tag():
     try:
-        stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
+        return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
     except Exception:
-        stored_commit_hash = "<none>"
+        return "<none>"
 
-    return stored_commit_hash
 
-
-def run(command, desc=None, errdesc=None, custom_env=None, live=False):
+def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
     if desc is not None:
         print(desc)
 
-    if live:
-        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
-        if result.returncode != 0:
-            raise RuntimeError(f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}""")
+    run_kwargs = {
+        "args": command,
+        "shell": True,
+        "env": os.environ if custom_env is None else custom_env,
+        "encoding": 'utf8',
+        "errors": 'ignore',
+    }
 
-        return ""
+    if not live:
+        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
 
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
+    result = subprocess.run(**run_kwargs)
 
     if result.returncode != 0:
+        error_bits = [
+            f"{errdesc or 'Error running command'}.",
+            f"Command: {command}",
+            f"Error code: {result.returncode}",
+        ]
+        if result.stdout:
+            error_bits.append(f"stdout: {result.stdout}")
+        if result.stderr:
+            error_bits.append(f"stderr: {result.stderr}")
+        raise RuntimeError("\n".join(error_bits))
 
-        message = f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}
-stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
-stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
-"""
-        raise RuntimeError(message)
-
-    return result.stdout.decode(encoding="utf8", errors="ignore")
-
-
-def check_run(command):
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
-    return result.returncode == 0
+    return (result.stdout or "")
 
 
 def is_installed(package):
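
In the hunk above, commit_hash() and the new git_tag() drop the stored_commit_hash global in favour of functools.lru_cache and call git with an argument list rather than a shell string, while run() gains combined stdout/stderr error reporting and a default_command_live switch. The caching half of that change in isolation; the command here is just an illustration:

    import subprocess
    from functools import lru_cache

    @lru_cache()
    def git_describe(path="."):
        # Runs `git describe --tags` once per path and caches the result, so
        # repeated calls return the stored string without spawning git again.
        try:
            return subprocess.check_output(
                ["git", "-C", path, "describe", "--tags"],
                shell=False, encoding="utf8",
            ).strip()
        except Exception:
            return "<none>"
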
@@ -117,20 +116,17 @@ def repo_dir(name):
     return os.path.join(script_path, dir_repos, name)
 
 
-def run_python(code, desc=None, errdesc=None):
-    return run(f'"{python}" -c "{code}"', desc, errdesc)
-
-
-def run_pip(args, desc=None):
-    if skip_install:
+def run_pip(command, desc=None, live=default_command_live):
+    if args.skip_install:
         return
 
     index_url_line = f' --index-url {index_url}' if index_url != '' else ''
-    return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
+    return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
 
 
-def check_run_python(code):
-    return check_run(f'"{python}" -c "{code}"')
+def check_run_python(code: str) -> bool:
+    result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
+    return result.returncode == 0
 
 
 def git_clone(url, dir, name, commithash=None):
@@ -223,15 +219,14 @@ def run_extensions_installers(settings_file):
 
 
 def prepare_environment():
-    global skip_install
-
-    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
+    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
+    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 
-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
-    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
-    clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
-    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
+    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
+    clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
+    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
     taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
@@ -249,15 +244,20 @@ def prepare_environment():
         check_python_version()
 
     commit = commit_hash()
+    tag = git_tag()
 
     print(f"Python {sys.version}")
+    print(f"Version: {tag}")
     print(f"Commit hash: {commit}")
 
     if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
 
-    if not args.skip_torch_cuda_test:
-        run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
+    if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
+        raise RuntimeError(
+            'Torch is not able to use GPU; '
+            'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
+        )
 
     if not is_installed("gfpgan"):
         run_pip(f"install {gfpgan_package}", "gfpgan")
@@ -271,7 +271,7 @@ def prepare_environment():
     if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
         if platform.system() == "Windows":
             if platform.python_version().startswith("3.10"):
-                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
             else:
                 print("Installation of xformers is not supported in this version of Python.")
                 print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
@@ -296,7 +296,7 @@ def prepare_environment():
 
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)
-    run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
+    run_pip(f"install -r \"{requirements_file}\"", "requirements")
 
     run_extensions_installers(settings_file=args.ui_settings_file)
 
@@ -305,7 +305,7 @@ def prepare_environment():
 
     if args.update_all_extensions:
         git_pull_recursive(extensions_dir)
-    
+
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)

BIN
modules/Roboto-Regular.ttf


+ 102 - 88
modules/api/api.py

@@ -6,7 +6,6 @@ import uvicorn
 import gradio as gr
 from threading import Lock
 from io import BytesIO
-from gradio.processing_utils import decode_base64_to_file
 from fastapi import APIRouter, Depends, FastAPI, Request, Response
 from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from fastapi.exceptions import HTTPException
@@ -16,7 +15,8 @@ from secrets import compare_digest
 
 import modules.shared as shared
 from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
+from modules.api import models
+from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
@@ -26,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import List
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
 
+
 def upscaler_to_index(name: str):
     try:
         return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
-    except:
-        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+
 
 def script_name_to_index(name, scripts):
     try:
         return [script.title().lower() for script in scripts].index(name.lower())
-    except:
-        raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+    except Exception as e:
+        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
+
 
 def validate_sampler_name(name):
     config = sd_samplers.all_samplers_map.get(name, None)
@@ -49,20 +52,23 @@ def validate_sampler_name(name):
 
     return name
 
+
 def setUpscalers(req: dict):
     reqDict = vars(req)
     reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
     reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
     return reqDict
 
+
 def decode_base64_to_image(encoding):
     if encoding.startswith("data:image/"):
         encoding = encoding.split(";")[1].split(",")[1]
     try:
         image = Image.open(BytesIO(base64.b64decode(encoding)))
         return image
-    except Exception as err:
-        raise HTTPException(status_code=500, detail="Invalid encoded image")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
@@ -93,6 +99,7 @@ def encode_pil_to_base64(image):
 
     return base64.b64encode(bytes_data)
 
+
 def api_middleware(app: FastAPI):
     rich_available = True
     try:
@@ -100,7 +107,7 @@ def api_middleware(app: FastAPI):
         import starlette # importing just so it can be placed on silent list
         from rich.console import Console
         console = Console()
-    except:
+    except Exception:
         import traceback
         rich_available = False
 
@@ -131,8 +138,8 @@ def api_middleware(app: FastAPI):
             "body": vars(e).get('body', ''),
             "errors": str(e),
         }
-        print(f"API error: {request.method}: {request.url} {err}")
         if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+            print(f"API error: {request.method}: {request.url} {err}")
             if rich_available:
                 console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
             else:
@@ -158,7 +165,7 @@ def api_middleware(app: FastAPI):
 class Api:
     def __init__(self, app: FastAPI, queue_lock: Lock):
         if shared.cmd_opts.api_auth:
-            self.credentials = dict()
+            self.credentials = {}
             for auth in shared.cmd_opts.api_auth.split(","):
                 user, password = auth.split(":")
                 self.credentials[user] = password
@@ -167,36 +174,37 @@ class Api:
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
-        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
-        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
-        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
-        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
-        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
+        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
+        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
         self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
-        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
-        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
-        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
-        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
-        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
-        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
-        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
-        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
+        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
-        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -220,17 +228,25 @@ class Api:
         script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
         script = script_runner.selectable_scripts[script_idx]
         return script, script_idx
-    
+
     def get_scripts_list(self):
-        t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
-        i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+        t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
+        i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
 
-        return ScriptsList(txt2img = t2ilist, img2img = i2ilist)  
+        return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
+
+    def get_script_info(self):
+        res = []
+
+        for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
+            res += [script.api_info for script in script_list if script.api_info is not None]
+
+        return res
 
     def get_script(self, script_name, script_runner):
         if script_name is None or script_name == "":
             return None, None
-        
+
         script_idx = script_name_to_index(script_name, script_runner.scripts)
         return script_runner.scripts[script_idx]
 
@@ -265,17 +281,19 @@ class Api:
         if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
             for alwayson_script_name in request.alwayson_scripts.keys():
                 alwayson_script = self.get_script(alwayson_script_name, script_runner)
-                if alwayson_script == None:
+                if alwayson_script is None:
                     raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                 # Selectable script in always on script param check
-                if alwayson_script.alwayson == False:
-                    raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+                if alwayson_script.alwayson is False:
+                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                 # always on script with no arg should always run so you don't really need to add them to the requests
                 if "args" in request.alwayson_scripts[alwayson_script_name]:
-                    script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
+                    # min between arg length in scriptrunner and arg length in the request
+                    for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
+                        script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
         return script_args
 
-    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
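
init_script_args() used to overwrite the whole args_from:args_to slice with whatever the request supplied, so a request with too few or too many values could shift or truncate an alwayson script's arguments; it now copies element by element, bounded by the smaller of the script's slot count and the supplied count. That merge rule in isolation:

    def merge_alwayson_args(script_args, args_from, args_to, requested):
        # Copy requested values into the script's own slots, never writing past
        # the slots it owns and never reading past what the client sent.
        for idx in range(min(args_to - args_from, len(requested))):
            script_args[args_from + idx] = requested[idx]
        return script_args

    # A script owning slots 3..6 given only two values: slots 3 and 4 are set,
    # 5 and 6 keep their defaults.
    merge_alwayson_args([0] * 8, 3, 7, ["a", "b"])
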
@@ -309,7 +327,7 @@ class Api:
             p.outpath_samples = opts.outdir_txt2img_samples
 
             shared.state.begin()
-            if selectable_scripts != None:
+            if selectable_scripts is not None:
                 p.script_args = script_args
                 processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
             else:
@@ -319,9 +337,9 @@ class Api:
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
-        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
 
-    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
@@ -366,7 +384,7 @@ class Api:
             p.outpath_samples = opts.outdir_img2img_samples
 
             shared.state.begin()
-            if selectable_scripts != None:
+            if selectable_scripts is not None:
                 p.script_args = script_args
                 processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
             else:
@@ -380,9 +398,9 @@ class Api:
             img2imgreq.init_images = None
             img2imgreq.mask = None
 
-        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
 
-    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+    def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
         reqDict = setUpscalers(req)
 
         reqDict['image'] = decode_base64_to_image(reqDict['image'])
@@ -390,31 +408,26 @@ class Api:
         with self.queue_lock:
             result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
 
-        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+        return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
 
-    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+    def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
         reqDict = setUpscalers(req)
 
-        def prepareFiles(file):
-            file = decode_base64_to_file(file.data, file_path=file.name)
-            file.orig_name = file.name
-            return file
-
-        reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
-        reqDict.pop('imageList')
+        image_list = reqDict.pop('imageList', [])
+        image_folder = [decode_base64_to_image(x.data) for x in image_list]
 
         with self.queue_lock:
-            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
 
-        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+        return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
-    def pnginfoapi(self, req: PNGInfoRequest):
+    def pnginfoapi(self, req: models.PNGInfoRequest):
         if(not req.image.strip()):
-            return PNGInfoResponse(info="")
+            return models.PNGInfoResponse(info="")
 
         image = decode_base64_to_image(req.image.strip())
         if image is None:
-            return PNGInfoResponse(info="")
+            return models.PNGInfoResponse(info="")
 
         geninfo, items = images.read_info_from_image(image)
         if geninfo is None:
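
extras_batch_images_api no longer materialises uploads as temporary files through gradio's decode_base64_to_file; each imageList entry's data field is decoded directly to a PIL image and passed to run_extras as image_folder. A client therefore just posts base64 strings; a hedged requests sketch (the endpoint comes from the route table above, the name/data entry fields from this hunk, and the remaining payload fields are assumptions):

    import base64
    import requests

    def upscale_batch(paths, base_url="http://127.0.0.1:7860"):
        image_list = []
        for path in paths:
            with open(path, "rb") as f:
                image_list.append({"name": path, "data": base64.b64encode(f.read()).decode()})

        payload = {"imageList": image_list, "upscaling_resize": 2, "upscaler_1": "Lanczos"}
        r = requests.post(f"{base_url}/sdapi/v1/extra-batch-images", json=payload, timeout=600)
        r.raise_for_status()
        return r.json()["images"]  # base64-encoded results, as in ExtrasBatchImagesResponse
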
@@ -422,13 +435,13 @@ class Api:
 
         items = {**{'parameters': geninfo}, **items}
 
-        return PNGInfoResponse(info=geninfo, items=items)
+        return models.PNGInfoResponse(info=geninfo, items=items)
 
-    def progressapi(self, req: ProgressRequest = Depends()):
+    def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
 
         if shared.state.job_count == 0:
-            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
 
         # avoid dividing zero
         progress = 0.01
@@ -450,9 +463,9 @@ class Api:
         if shared.state.current_image and not req.skip_current_image:
             current_image = encode_pil_to_base64(shared.state.current_image)
 
-        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
 
-    def interrogateapi(self, interrogatereq: InterrogateRequest):
+    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
         image_b64 = interrogatereq.image
         if image_b64 is None:
             raise HTTPException(status_code=404, detail="Image not found")
@@ -469,7 +482,7 @@ class Api:
             else:
                 raise HTTPException(status_code=404, detail="Model not found")
 
-        return InterrogateResponse(caption=processed)
+        return models.InterrogateResponse(caption=processed)
 
     def interruptapi(self):
         shared.state.interrupt()
@@ -574,36 +587,36 @@ class Api:
             filename = create_embedding(**args) # create empty embedding
             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
             shared.state.end()
-            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+            return models.CreateResponse(info=f"create embedding filename: {filename}")
         except AssertionError as e:
             shared.state.end()
-            return TrainResponse(info = "create embedding error: {error}".format(error = e))
+            return models.TrainResponse(info=f"create embedding error: {e}")
 
     def create_hypernetwork(self, args: dict):
         try:
             shared.state.begin()
             filename = create_hypernetwork(**args) # create empty embedding
             shared.state.end()
-            return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+            return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
         except AssertionError as e:
             shared.state.end()
-            return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
+            return models.TrainResponse(info=f"create hypernetwork error: {e}")
 
     def preprocess(self, args: dict):
         try:
             shared.state.begin()
             preprocess(**args) # quick operation unless blip/booru interrogation is enabled
             shared.state.end()
-            return PreprocessResponse(info = 'preprocess complete')
+            return models.PreprocessResponse(info = 'preprocess complete')
         except KeyError as e:
             shared.state.end()
-            return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
         except AssertionError as e:
             shared.state.end()
-            return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+            return models.PreprocessResponse(info=f"preprocess error: {e}")
         except FileNotFoundError as e:
             shared.state.end()
-            return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
+            return models.PreprocessResponse(info=f'preprocess error: {e}')
 
     def train_embedding(self, args: dict):
         try:
@@ -621,10 +634,10 @@ class Api:
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
                 shared.state.end()
-            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
         except AssertionError as msg:
             shared.state.end()
-            return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+            return models.TrainResponse(info=f"train embedding error: {msg}")
 
     def train_hypernetwork(self, args: dict):
         try:
@@ -645,14 +658,15 @@ class Api:
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
                 shared.state.end()
-            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
-        except AssertionError as msg:
+            return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+        except AssertionError:
             shared.state.end()
-            return TrainResponse(info="train embedding error: {error}".format(error=error))
+            return models.TrainResponse(info=f"train embedding error: {error}")
 
     def get_memory(self):
         try:
-            import os, psutil
+            import os
+            import psutil
             process = psutil.Process(os.getpid())
             res = process.memory_info() # only rss is cross-platform guaranteed so we don't rely on other values
             ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
@@ -679,10 +693,10 @@ class Api:
                     'events': warnings,
                 }
             else:
-                cuda = { 'error': 'unavailable' }
+                cuda = {'error': 'unavailable'}
         except Exception as err:
-            cuda = { 'error': f'{err}' }
-        return MemoryResponse(ram = ram, cuda = cuda)
+            cuda = {'error': f'{err}'}
+        return models.MemoryResponse(ram=ram, cuda=cuda)
 
     def launch(self, server_name, port):
         self.app.include_router(self.router)

+ 22 - 4
modules/api/models.py

@@ -223,8 +223,9 @@ for key in _options:
     if(_options[key].dest != 'help'):
         flag = _options[key]
         _type = str
-        if _options[key].default is not None: _type = type(_options[key].default)
-        flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+        if _options[key].default is not None:
+            _type = type(_options[key].default)
+        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
 
 FlagsModel = create_model("Flags", **flags)
 
@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
     cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
 
+
 class ScriptsList(BaseModel):
-    txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
-    img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
+    txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+    img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
+
+
+class ScriptArg(BaseModel):
+    label: str = Field(default=None, title="Label", description="Name of the argument in UI")
+    value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
+    minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argument in UI")
+    maximum: Optional[Any] = Field(default=None, title="Maximum", description="Maximum allowed value for the argument in UI")
+    step: Optional[Any] = Field(default=None, title="Step", description="Step for changing value of the argument in UI")
+    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+
+
+class ScriptInfo(BaseModel):
+    name: str = Field(default=None, title="Name", description="Script name")
+    is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
+    is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
+    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")

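Together with the /sdapi/v1/script-info route added in api.py, these models let API clients discover each script's argument layout instead of hard-coding positions. A small requests example; the base URL is an assumption and the printed fields come from ScriptInfo and ScriptArg above:

    import requests

    def list_scripts(base_url="http://127.0.0.1:7860"):
        for script in requests.get(f"{base_url}/sdapi/v1/script-info", timeout=10).json():
            kind = "img2img" if script["is_img2img"] else "txt2img"
            flavor = "alwayson" if script["is_alwayson"] else "selectable"
            labels = [arg["label"] for arg in script["args"]]
            print(f"{script['name']} ({kind}, {flavor}): {labels}")
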
+ 4 - 2
modules/call_queue.py

@@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
             try:
                 res = func(*args, **kwargs)
+                progress.record_results(id_task, res)
             finally:
                 progress.finish_task(id_task)
 
@@ -59,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
             max_debug_str_len = 131072 # (1024*1024)/8
 
             print("Error completing request", file=sys.stderr)
-            argStr = f"Arguments: {str(args)} {str(kwargs)}"
+            argStr = f"Arguments: {args} {kwargs}"
             print(argStr[:max_debug_str_len], file=sys.stderr)
             if len(argStr) > max_debug_str_len:
                 print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
@@ -72,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
             if extra_outputs_array is None:
                 extra_outputs_array = [None, '']
 
-            res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+            error_message = f'{type(e).__name__}: {e}'
+            res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
 
         shared.state.skipped = False
         shared.state.interrupted = False

+ 4 - 1
modules/cmd_args.py

@@ -1,6 +1,6 @@
 import argparse
 import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file  # noqa: F401
 
 parser = argparse.ArgumentParser()
 
@@ -95,9 +95,12 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
 parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
 parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
 parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
+parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
 parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
 parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
 parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
+parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
+parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')

+ 11 - 13
modules/codeformer/codeformer_arch.py

@@ -1,14 +1,12 @@
 # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
 
 import math
-import numpy as np
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional
 
-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
 from basicsr.utils.registry import ARCH_REGISTRY
 
 def calc_mean_std(feat, eps=1e-5):
@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
                 tgt_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 query_pos: Optional[Tensor] = None):
-        
+
         # self attention
         tgt2 = self.norm1(tgt)
         q = k = self.with_pos_embed(tgt2, query_pos)
@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
 
 @ARCH_REGISTRY.register()
 class CodeFormer(VQAutoEncoder):
-    def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
+    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
-                connect_list=['32', '64', '128', '256'],
-                fix_modules=['quantize','generator']):
+                connect_list=('32', '64', '128', '256'),
+                fix_modules=('quantize', 'generator')):
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
 
         if fix_modules is not None:
@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
         self.feat_emb = nn.Linear(256, self.dim_embd)
 
         # transformer
-        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) 
+        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                     for _ in range(self.n_layers)])
 
         # logits_predict head
         self.idx_pred_layer = nn.Sequential(
             nn.LayerNorm(dim_embd),
             nn.Linear(dim_embd, codebook_size, bias=False))
-        
+
         self.channels = {
             '16': 512,
             '32': 256,
@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
         enc_feat_dict = {}
         out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
         for i, block in enumerate(self.encoder.blocks):
-            x = block(x) 
+            x = block(x)
             if i in out_list:
                 enc_feat_dict[str(x.shape[-1])] = x.clone()
 
@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
         fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
 
         for i, block in enumerate(self.generator.blocks):
-            x = block(x) 
+            x = block(x)
             if i in fuse_list: # fuse after i-th block
                 f_size = str(x.shape[-1])
                 if w>0:
                     x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
         out = x
         # logits doesn't need softmax before cross_entropy loss
-        return out, logits, lq_feat
+        return out, logits, lq_feat

+ 21 - 23
modules/codeformer/vqgan_arch.py

@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
 https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
 
 '''
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import copy
 from basicsr.utils import get_root_logger
 from basicsr.utils.registry import ARCH_REGISTRY
 
 def normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-    
+
 
 @torch.jit.script
 def swish(x):
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
         # compute attention
         b, c, h, w = q.shape
         q = q.reshape(b, c, h*w)
-        q = q.permute(0, 2, 1)   
+        q = q.permute(0, 2, 1)
         k = k.reshape(b, c, h*w)
-        w_ = torch.bmm(q, k) 
+        w_ = torch.bmm(q, k)
         w_ = w_ * (int(c)**(-0.5))
         w_ = F.softmax(w_, dim=2)
 
         # attend to values
         v = v.reshape(b, c, h*w)
-        w_ = w_.permute(0, 2, 1) 
+        w_ = w_.permute(0, 2, 1)
         h_ = torch.bmm(v, w_)
         h_ = h_.reshape(b, c, h, w)
 
@@ -272,18 +270,18 @@ class Encoder(nn.Module):
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-            
+
         return x
 
 
 class Generator(nn.Module):
     def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
         super().__init__()
-        self.nf = nf 
-        self.ch_mult = ch_mult 
+        self.nf = nf
+        self.ch_mult = ch_mult
         self.num_resolutions = len(self.ch_mult)
         self.num_res_blocks = res_blocks
-        self.resolution = img_size 
+        self.resolution = img_size
         self.attn_resolutions = attn_resolutions
         self.in_channels = emb_dim
         self.out_channels = 3
@@ -317,29 +315,29 @@ class Generator(nn.Module):
         blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
 
         self.blocks = nn.ModuleList(blocks)
-   
+
 
     def forward(self, x):
         for block in self.blocks:
             x = block(x)
-            
+
         return x
 
-  
+
 @ARCH_REGISTRY.register()
 class VQAutoEncoder(nn.Module):
-    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
         super().__init__()
         logger = get_root_logger()
-        self.in_channels = 3 
-        self.nf = nf 
-        self.n_blocks = res_blocks 
+        self.in_channels = 3
+        self.nf = nf
+        self.n_blocks = res_blocks
         self.codebook_size = codebook_size
         self.embed_dim = emb_dim
         self.ch_mult = ch_mult
         self.resolution = img_size
-        self.attn_resolutions = attn_resolutions
+        self.attn_resolutions = attn_resolutions or [16]
         self.quantizer_type = quantizer
         self.encoder = Encoder(
             self.in_channels,
@@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
                 self.kl_weight
             )
         self.generator = Generator(
-            self.nf, 
+            self.nf,
             self.embed_dim,
-            self.ch_mult, 
-            self.n_blocks, 
-            self.resolution, 
+            self.ch_mult,
+            self.n_blocks,
+            self.resolution,
             self.attn_resolutions
         )
 
@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
                 raise ValueError('Wrong params!')
 
     def forward(self, x):
-        return self.main(x)
+        return self.main(x)

+ 2 - 4
modules/codeformer_model.py

@@ -33,11 +33,9 @@ def setup_model(dirname):
     try:
         from torchvision.transforms.functional import normalize
         from modules.codeformer.codeformer_arch import CodeFormer
-        from basicsr.utils.download_util import load_file_from_url
-        from basicsr.utils import imwrite, img2tensor, tensor2img
+        from basicsr.utils import img2tensor, tensor2img
         from facelib.utils.face_restoration_helper import FaceRestoreHelper
         from facelib.detection.retinaface import retinaface
-        from modules.shared import cmd_opts
 
         net_class = CodeFormer
 
@@ -96,7 +94,7 @@ def setup_model(dirname):
                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                 self.face_helper.align_warp_face()
 
-                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+                for cropped_face in self.face_helper.cropped_faces:
                     cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                     normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                     cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

+ 202 - 0
modules/config_states.py

@@ -0,0 +1,202 @@
+"""
+Supports saving and restoring webui and extensions from a known working set of commits
+"""
+
+import os
+import sys
+import traceback
+import json
+import time
+import tqdm
+
+from datetime import datetime
+from collections import OrderedDict
+import git
+
+from modules import shared, extensions
+from modules.paths_internal import script_path, config_states_dir
+
+
+all_config_states = OrderedDict()
+
+
+def list_config_states():
+    global all_config_states
+
+    all_config_states.clear()
+    os.makedirs(config_states_dir, exist_ok=True)
+
+    config_states = []
+    for filename in os.listdir(config_states_dir):
+        if filename.endswith(".json"):
+            path = os.path.join(config_states_dir, filename)
+            with open(path, "r", encoding="utf-8") as f:
+                j = json.load(f)
+                j["filepath"] = path
+                config_states.append(j)
+
+    config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
+
+    for cs in config_states:
+        timestamp = time.asctime(time.gmtime(cs["created_at"]))
+        name = cs.get("name", "Config")
+        full_name = f"{name}: {timestamp}"
+        all_config_states[full_name] = cs
+
+    return all_config_states
+
+
+def get_webui_config():
+    webui_repo = None
+
+    try:
+        if os.path.exists(os.path.join(script_path, ".git")):
+            webui_repo = git.Repo(script_path)
+    except Exception:
+        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+
+    webui_remote = None
+    webui_commit_hash = None
+    webui_commit_date = None
+    webui_branch = None
+    if webui_repo and not webui_repo.bare:
+        try:
+            webui_remote = next(webui_repo.remote().urls, None)
+            head = webui_repo.head.commit
+            webui_commit_date = webui_repo.head.commit.committed_date
+            webui_commit_hash = head.hexsha
+            webui_branch = webui_repo.active_branch.name
+
+        except Exception:
+            webui_remote = None
+
+    return {
+        "remote": webui_remote,
+        "commit_hash": webui_commit_hash,
+        "commit_date": webui_commit_date,
+        "branch": webui_branch,
+    }
+
+
+def get_extension_config():
+    ext_config = {}
+
+    for ext in extensions.extensions:
+        ext.read_info_from_repo()
+
+        entry = {
+            "name": ext.name,
+            "path": ext.path,
+            "enabled": ext.enabled,
+            "is_builtin": ext.is_builtin,
+            "remote": ext.remote,
+            "commit_hash": ext.commit_hash,
+            "commit_date": ext.commit_date,
+            "branch": ext.branch,
+            "have_info_from_repo": ext.have_info_from_repo
+        }
+
+        ext_config[ext.name] = entry
+
+    return ext_config
+
+
+def get_config():
+    creation_time = datetime.now().timestamp()
+    webui_config = get_webui_config()
+    ext_config = get_extension_config()
+
+    return {
+        "created_at": creation_time,
+        "webui": webui_config,
+        "extensions": ext_config
+    }
+
+
+def restore_webui_config(config):
+    print("* Restoring webui state...")
+
+    if "webui" not in config:
+        print("Error: No webui data saved to config")
+        return
+
+    webui_config = config["webui"]
+
+    if "commit_hash" not in webui_config:
+        print("Error: No commit saved to webui config")
+        return
+
+    webui_commit_hash = webui_config.get("commit_hash", None)
+    webui_repo = None
+
+    try:
+        if os.path.exists(os.path.join(script_path, ".git")):
+            webui_repo = git.Repo(script_path)
+    except Exception:
+        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+        return
+
+    try:
+        webui_repo.git.fetch(all=True)
+        webui_repo.git.reset(webui_commit_hash, hard=True)
+        print(f"* Restored webui to commit {webui_commit_hash}.")
+    except Exception:
+        print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+
+
+def restore_extension_config(config):
+    print("* Restoring extension state...")
+
+    if "extensions" not in config:
+        print("Error: No extension data saved to config")
+        return
+
+    ext_config = config["extensions"]
+
+    results = []
+    disabled = []
+
+    for ext in tqdm.tqdm(extensions.extensions):
+        if ext.is_builtin:
+            continue
+
+        ext.read_info_from_repo()
+        current_commit = ext.commit_hash
+
+        if ext.name not in ext_config:
+            ext.disabled = True
+            disabled.append(ext.name)
+            results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
+            continue
+
+        entry = ext_config[ext.name]
+
+        if "commit_hash" in entry and entry["commit_hash"]:
+            try:
+                ext.fetch_and_reset_hard(entry["commit_hash"])
+                ext.read_info_from_repo()
+                if current_commit != entry["commit_hash"]:
+                    results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
+            except Exception as ex:
+                results.append((ext, current_commit[:8], False, ex))
+        else:
+            results.append((ext, current_commit[:8], False, "No commit hash found in config"))
+
+        if not entry.get("enabled", False):
+            ext.disabled = True
+            disabled.append(ext.name)
+        else:
+            ext.disabled = False
+
+    shared.opts.disabled_extensions = disabled
+    shared.opts.save(shared.config_filename)
+
+    print("* Finished restoring extensions. Results:")
+    for ext, prev_commit, success, result in results:
+        if success:
+            print(f"  + {ext.name}: {prev_commit} -> {result}")
+        else:
+            print(f"  ! {ext.name}: FAILURE ({result})")

+ 1 - 2
modules/deepbooru.py

@@ -2,7 +2,6 @@ import os
 import re
 
 import torch
-from PIL import Image
 import numpy as np
 
 from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:
 
         res = []
 
-        filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
 
         for tag in [x for x in tags if x not in filtertags]:
             probability = probability_dict[tag]

+ 7 - 3
modules/devices.py

@@ -65,7 +65,7 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
         # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
             torch.backends.cudnn.benchmark = True
 
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -92,14 +92,18 @@ def cond_cast_float(input):
 
 
 def randn(seed, shape):
+    from modules.shared import opts
+
     torch.manual_seed(seed)
-    if device.type == 'mps':
+    if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
 
 
 def randn_without_seed(shape):
-    if device.type == 'mps':
+    from modules.shared import opts
+
+    if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
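
The new opts.randn_source == "CPU" branch mirrors the existing MPS path: sample on the CPU generator and then move the tensor to the compute device, so a given seed yields the same starting noise regardless of which GPU is in use. A small standalone sketch of that idea (plain PyTorch, no webui imports):

# Standalone sketch of the "CPU RNG" behaviour added above: seed and sample on
# the CPU generator, then move the result to the compute device.
import torch

def reproducible_noise(seed, shape, device):
    torch.manual_seed(seed)
    return torch.randn(shape, device=torch.device("cpu")).to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = reproducible_noise(42, (1, 4, 64, 64), device)
b = reproducible_noise(42, (1, 4, 64, 64), device)
assert torch.equal(a, b)  # same seed, same noise, independent of the device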
 

+ 10 - 11
modules/esrgan_model.py

@@ -6,7 +6,7 @@ from PIL import Image
 from basicsr.utils.download_util import load_file_from_url
 
 import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
+from modules import modelloader, images, devices
 from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts
 
@@ -16,9 +16,7 @@ def mod2normal(state_dict):
     # this code is copied from https://github.com/victorca25/iNNfer
     if 'conv_first.weight' in state_dict:
         crt_net = {}
-        items = []
-        for k, v in state_dict.items():
-            items.append(k)
+        items = list(state_dict)
 
         crt_net['model.0.weight'] = state_dict['conv_first.weight']
         crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
     if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
         re8x = 0
         crt_net = {}
-        items = []
-        for k, v in state_dict.items():
-            items.append(k)
+        items = list(state_dict)
 
         crt_net['model.0.weight'] = state_dict['conv_first.weight']
         crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
 
     def load_model(self, path: str):
         if "http" in path:
-            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
-                                          file_name="%s.pth" % self.model_name,
-                                          progress=True)
+            filename = load_file_from_url(
+                url=self.model_url,
+                model_dir=self.model_path,
+                file_name=f"{self.model_name}.pth",
+                progress=True,
+            )
         else:
             filename = path
         if not os.path.exists(filename) or filename is None:
-            print("Unable to load %s from %s" % (self.model_path, filename))
+            print(f"Unable to load {self.model_path} from {filename}")
             return None
 
         state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

+ 12 - 11
modules/esrgan_model_arch.py

@@ -2,7 +2,6 @@
 
 from collections import OrderedDict
 import math
-import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
         elif upsample_mode == 'pixelshuffle':
             upsample_block = pixelshuffle_block
         else:
-            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+            raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
         if upscale == 3:
             upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
         else:
@@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
     Modified options that can be used:
         - "Partial Convolution based Padding" arXiv:1811.11718
         - "Spectral normalization" arXiv:1802.05957
-        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. 
+        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
             {Rakotonirina} and A. {Rasoanaivo}
     """
 
@@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
             scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
             sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
             x = x + sampled_noise
-        return x 
+        return x
 
 def conv1x1(in_planes, out_planes, stride=1):
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -261,10 +260,10 @@ class Upsample(nn.Module):
 
     def extra_repr(self):
         if self.scale_factor is not None:
-            info = 'scale_factor=' + str(self.scale_factor)
+            info = f'scale_factor={self.scale_factor}'
         else:
-            info = 'size=' + str(self.size)
-        info += ', mode=' + self.mode
+            info = f'size={self.size}'
+        info += f', mode={self.mode}'
         return info
 
 
@@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
     elif act_type == 'sigmoid':  # [0, 1] range output
         layer = nn.Sigmoid()
     else:
-        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+        raise NotImplementedError(f'activation layer [{act_type}] is not found')
     return layer
 
 
@@ -372,7 +371,7 @@ def norm(norm_type, nc):
     elif norm_type == 'none':
         def norm_layer(x): return Identity()
     else:
-        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+        raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
     return layer
 
 
@@ -388,7 +387,7 @@ def pad(pad_type, padding):
     elif pad_type == 'zero':
         layer = nn.ZeroPad2d(padding)
     else:
-        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+        raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
     return layer
 
 
@@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
                pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
                spectral_norm=False):
     """ Conv layer with padding, normalization, activation """
-    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+    assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
     padding = get_valid_padding(kernel_size, dilation)
     p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
     padding = padding if pad_type == 'zero' else 0
 
     if convtype=='PartialConv2D':
+        from torchvision.ops import PartialConv2d  # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
         c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                dilation=dilation, bias=bias, groups=groups)
     elif convtype=='DeformConv2D':
+        from torchvision.ops import DeformConv2d  # not tested
         c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                dilation=dilation, bias=bias, groups=groups)
     elif convtype=='Conv3D':

+ 39 - 12
modules/extensions.py

@@ -1,12 +1,12 @@
 import os
 import sys
+import threading
 import traceback
 
-import time
 import git
 
 from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401
 
 extensions = []
 
@@ -24,6 +24,8 @@ def active():
 
 
 class Extension:
+    lock = threading.Lock()
+
     def __init__(self, name, path, enabled=True, is_builtin=False):
         self.name = name
         self.path = path
@@ -31,16 +33,24 @@ class Extension:
         self.status = ''
         self.can_update = False
         self.is_builtin = is_builtin
+        self.commit_hash = ''
+        self.commit_date = None
         self.version = ''
+        self.branch = None
         self.remote = None
         self.have_info_from_repo = False
 
     def read_info_from_repo(self):
-        if self.have_info_from_repo:
+        if self.is_builtin or self.have_info_from_repo:
             return
 
-        self.have_info_from_repo = True
+        with self.lock:
+            if self.have_info_from_repo:
+                return
+
+            self.do_read_info_from_repo()
 
+    def do_read_info_from_repo(self):
         repo = None
         try:
             if os.path.exists(os.path.join(self.path, ".git")):
@@ -55,13 +65,18 @@ class Extension:
             try:
                 self.status = 'unknown'
                 self.remote = next(repo.remote().urls, None)
-                head = repo.head.commit
-                ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
-                self.version = f'{head.hexsha[:8]} ({ts})'
-
-            except Exception:
+                self.commit_date = repo.head.commit.committed_date
+                if repo.active_branch:
+                    self.branch = repo.active_branch.name
+                self.commit_hash = repo.head.commit.hexsha
+                self.version = repo.git.describe("--always", "--tags")  # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
+
+            except Exception as ex:
+                print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
                 self.remote = None
 
+        self.have_info_from_repo = True
+
     def list_files(self, subdir, extension):
         from modules import scripts
 
@@ -82,18 +97,30 @@ class Extension:
         for fetch in repo.remote().fetch(dry_run=True):
             if fetch.flags != fetch.HEAD_UPTODATE:
                 self.can_update = True
-                self.status = "behind"
+                self.status = "new commits"
                 return
 
+        try:
+            origin = repo.rev_parse('origin')
+            if repo.head.commit != origin:
+                self.can_update = True
+                self.status = "behind HEAD"
+                return
+        except Exception:
+            self.can_update = False
+            self.status = "unknown (remote error)"
+            return
+
         self.can_update = False
         self.status = "latest"
 
-    def fetch_and_reset_hard(self):
+    def fetch_and_reset_hard(self, commit='origin'):
         repo = git.Repo(self.path)
         # Fix: `error: Your local changes to the following files would be overwritten by merge`,
         # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
         repo.git.fetch(all=True)
-        repo.git.reset('origin', hard=True)
+        repo.git.reset(commit, hard=True)
+        self.have_info_from_repo = False
 
 
 def list_extensions():
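
read_info_from_repo now wraps the expensive git query in a double-checked pattern around a class-level lock, so parallel callers do the work only once. A generic sketch of that pattern with illustrative names:

# Generic sketch of the double-checked initialization used by read_info_from_repo.
import threading

class LazyRepoInfo:
    lock = threading.Lock()          # shared lock, like Extension.lock above

    def __init__(self):
        self.have_info = False

    def read_info(self):
        if self.have_info:           # cheap check without taking the lock
            return
        with self.lock:
            if self.have_info:       # another thread may have finished meanwhile
                return
            self.do_read_info()

    def do_read_info(self):
        # ... expensive work, e.g. querying the git repository ...
        self.have_info = True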

+ 1 - 1
modules/extra_networks.py

@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
     """call deactivate for extra networks in extra_network_data in specified order, then call
     deactivate for all remaining registered networks"""
 
-    for extra_network_name, extra_network_args in extra_network_data.items():
+    for extra_network_name in extra_network_data:
         extra_network = extra_network_registry.get(extra_network_name, None)
         if extra_network is None:
             continue

+ 4 - 3
modules/extra_networks_hypernet.py

@@ -1,4 +1,4 @@
-from modules import extra_networks, shared, extra_networks
+from modules import extra_networks, shared
 from modules.hypernetworks import hypernetwork
 
 
@@ -9,8 +9,9 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_hypernetwork
 
-        if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
-            p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+        if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+            hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
+            p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
 
         names = []

+ 51 - 5
modules/extras.py

@@ -1,6 +1,7 @@
 import os
 import re
 import shutil
+import json
 
 
 import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
     return tensor
 
 
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
     shared.state.begin()
     shared.state.job = 'model-merge'
 
@@ -135,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     result_is_instruct_pix2pix_model = False
 
     if theta_func2:
-        shared.state.textinfo = f"Loading B"
+        shared.state.textinfo = "Loading B"
         print(f"Loading {secondary_model_info.filename}...")
         theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
     else:
         theta_1 = None
 
     if theta_func1:
-        shared.state.textinfo = f"Loading C"
+        shared.state.textinfo = "Loading C"
         print(f"Loading {tertiary_model_info.filename}...")
         theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
 
@@ -198,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
                     result_is_inpainting_model = True
             else:
                 theta_0[key] = theta_func2(a, b, multiplier)
-            
+
             theta_0[key] = to_half(theta_0[key], save_as_half)
 
         shared.state.sampling_step += 1
@@ -241,13 +242,58 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     shared.state.textinfo = "Saving"
     print(f"Saving to {output_modelname}...")
 
+    metadata = None
+
+    if save_metadata:
+        metadata = {"format": "pt"}
+
+        merge_recipe = {
+            "type": "webui", # indicate this model was merged with webui's built-in merger
+            "primary_model_hash": primary_model_info.sha256,
+            "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
+            "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
+            "interp_method": interp_method,
+            "multiplier": multiplier,
+            "save_as_half": save_as_half,
+            "custom_name": custom_name,
+            "config_source": config_source,
+            "bake_in_vae": bake_in_vae,
+            "discard_weights": discard_weights,
+            "is_inpainting": result_is_inpainting_model,
+            "is_instruct_pix2pix": result_is_instruct_pix2pix_model
+        }
+        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+
+        sd_merge_models = {}
+
+        def add_model_metadata(checkpoint_info):
+            checkpoint_info.calculate_shorthash()
+            sd_merge_models[checkpoint_info.sha256] = {
+                "name": checkpoint_info.name,
+                "legacy_hash": checkpoint_info.hash,
+                "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
+            }
+
+            sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
+
+        add_model_metadata(primary_model_info)
+        if secondary_model_info:
+            add_model_metadata(secondary_model_info)
+        if tertiary_model_info:
+            add_model_metadata(tertiary_model_info)
+
+        metadata["sd_merge_models"] = json.dumps(sd_merge_models)
+
     _, extension = os.path.splitext(output_modelname)
     if extension.lower() == ".safetensors":
-        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
     else:
         torch.save(theta_0, output_modelname)
 
     sd_models.list_models()
+    created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
+    if created_model:
+        created_model.calculate_shorthash()
 
     create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
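
The sd_merge_recipe / sd_merge_models entries above are only written for .safetensors output. A hedged sketch of reading them back with the safetensors safe_open API; the checkpoint path is illustrative.

# Hedged sketch: reading the merge metadata written above back out of a
# .safetensors checkpoint (path is illustrative).
import json
from safetensors import safe_open

with safe_open("models/Stable-diffusion/merged.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

recipe = json.loads(metadata.get("sd_merge_recipe", "{}"))
print(recipe.get("interp_method"), recipe.get("multiplier"))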
 

+ 19 - 9
modules/generation_parameters_copypaste.py

@@ -1,15 +1,11 @@
 import base64
-import html
 import io
-import math
 import os
 import re
-from pathlib import Path
 
 import gradio as gr
 from modules.paths import data_path
 from modules import shared, ui_tempdir, script_callbacks
-import tempfile
 from PIL import Image
 
 re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -23,14 +19,14 @@ registered_param_bindings = []
 
 
 class ParamBinding:
-    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
         self.paste_button = paste_button
         self.tabname = tabname
         self.source_text_component = source_text_component
         self.source_image_component = source_image_component
         self.source_tabname = source_tabname
         self.override_settings_component = override_settings_component
-        self.paste_field_names = paste_field_names
+        self.paste_field_names = paste_field_names or []
 
 
 def reset():
@@ -59,6 +55,7 @@ def image_from_url_text(filedata):
         is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
         assert is_in_right_dir, 'trying to open image file outside of allowed directories'
 
+        filename = filename.rsplit('?', 1)[0]
         return Image.open(filename)
 
     if type(filedata) == list:
@@ -129,6 +126,7 @@ def connect_paste_params_buttons():
                 _js=jsfunc,
                 inputs=[binding.source_image_component],
                 outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+                show_progress=False,
             )
 
         if binding.source_text_component is not None and fields is not None:
@@ -140,6 +138,7 @@ def connect_paste_params_buttons():
                 fn=lambda *x: x,
                 inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
                 outputs=[field for field, name in fields if name in paste_field_names],
+                show_progress=False,
             )
 
         binding.paste_button.click(
@@ -147,6 +146,7 @@ def connect_paste_params_buttons():
             _js=f"switch_to_{binding.tabname}",
             inputs=None,
             outputs=None,
+            show_progress=False,
         )
 
 
@@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         lines.append(lastline)
         lastline = ''
 
-    for i, line in enumerate(lines):
+    for line in lines:
         line = line.strip()
         if line.startswith("Negative prompt:"):
             done_with_prompt = True
@@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
         v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
         m = re_imagesize.match(v)
         if m is not None:
-            res[k+"-1"] = m.group(1)
-            res[k+"-2"] = m.group(2)
+            res[f"{k}-1"] = m.group(1)
+            res[f"{k}-2"] = m.group(2)
         else:
             res[k] = v
 
@@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 
     restore_old_hires_fix_params(res)
 
+    # Missing RNG means the default was set, which is GPU RNG
+    if "RNG" not in res:
+        res["RNG"] = "GPU"
+
     return res
 
 
@@ -304,6 +308,10 @@ infotext_to_setting_name_mapping = [
     ('UniPC skip type', 'uni_pc_skip_type'),
     ('UniPC order', 'uni_pc_order'),
     ('UniPC lower order final', 'uni_pc_lower_order_final'),
+    ('Token merging ratio', 'token_merging_ratio'),
+    ('Token merging ratio hr', 'token_merging_ratio_hr'),
+    ('RNG', 'randn_source'),
+    ('NGMS', 's_min_uncond'),
 ]
 
 
@@ -403,12 +411,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
         fn=paste_func,
         inputs=[input_comp],
         outputs=[x[0] for x in paste_fields],
+        show_progress=False,
     )
     button.click(
         fn=None,
         _js=f"recalculate_prompts_{tabname}",
         inputs=[],
         outputs=[],
+        show_progress=False,
     )
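
The new ('RNG', 'randn_source') and ('NGMS', 's_min_uncond') mapping entries, together with the GPU default above, operate on the dictionary that parse_generation_parameters builds. A small illustration of how one infotext line is split by the re_param pattern defined at the top of this file (the sample line is made up):

# Illustration of infotext splitting; the sample line is made up and the regex
# is copied from re_param_code above.
import re

re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)')
line = "Steps: 20, Sampler: Euler a, CFG scale: 7, Size: 512x512, RNG: CPU"
print(dict(re_param.findall(line)))
# roughly: {'Steps': '20', 'Sampler': 'Euler a', 'CFG scale': '7', 'Size': '512x512', 'RNG': 'CPU'}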
 
 

+ 1 - 1
modules/gfpgan_model.py

@@ -78,7 +78,7 @@ def setup_model(dirname):
 
     try:
         from gfpgan import GFPGANer
-        from facexlib import detection, parsing
+        from facexlib import detection, parsing  # noqa: F401
         global user_path
         global have_gfpgan
         global gfpgan_constructor

+ 2 - 2
modules/hashes.py

@@ -13,7 +13,7 @@ cache_data = None
 
 
 def dump_cache():
-    with filelock.FileLock(cache_filename+".lock"):
+    with filelock.FileLock(f"{cache_filename}.lock"):
         with open(cache_filename, "w", encoding="utf8") as file:
             json.dump(cache_data, file, indent=4)
 
@@ -22,7 +22,7 @@ def cache(subsection):
     global cache_data
 
     if cache_data is None:
-        with filelock.FileLock(cache_filename+".lock"):
+        with filelock.FileLock(f"{cache_filename}.lock"):
             if not os.path.isfile(cache_filename):
                 cache_data = {}
             else:

+ 14 - 15
modules/hypernetworks/hypernetwork.py

@@ -1,4 +1,3 @@
-import csv
 import datetime
 import glob
 import html
@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
 from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
 
-from collections import defaultdict, deque
+from collections import deque
 from statistics import stdev, mean
 
 
@@ -178,34 +177,34 @@ class Hypernetwork:
 
     def weights(self):
         res = []
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 res += layer.parameters()
         return res
 
     def train(self, mode=True):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.train(mode=mode)
                 for param in layer.parameters():
                     param.requires_grad = mode
 
     def to(self, device):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.to(device)
 
         return self
 
     def set_multiplier(self, multiplier):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.multiplier = multiplier
 
         return self
 
     def eval(self):
-        for k, layers in self.layers.items():
+        for layers in self.layers.values():
             for layer in layers:
                 layer.eval()
                 for param in layer.parameters():
@@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
     k = self.to_k(context_k)
     v = self.to_v(context_v)
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
 
     sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
 
@@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
         return hypernetwork, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-    
+
     clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
     if clip_grad:
         clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
             print(e)
 
     scaler = torch.cuda.amp.GradScaler()
-    
+
     batch_size = ds.batch_size
     gradient_step = ds.gradient_step
     # n steps = batch_size * gradient_step * n image processed
@@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     try:
         sd_hijack_checkpoint.add()
 
-        for i in range((steps-initial_step) * gradient_step):
+        for _ in range((steps-initial_step) * gradient_step):
             if scheduler.finished:
                 break
             if shared.state.interrupted:
@@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 
                 if clip_grad:
                     clip_grad_sched.step(hypernetwork.step)
-                
+
                 with devices.autocast():
                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                     if use_weight:
@@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 
                     _loss_step += loss.item()
                 scaler.scale(loss).backward()
-                
+
                 # go back until we reach gradient accumulation steps
                 if (j + 1) % gradient_step != 0:
                     continue
                 loss_logging.append(_loss_step)
                 if clip_grad:
                     clip_grad(weights, clip_grad_sched.learn_rate)
-                
+
                 scaler.step(optimizer)
                 scaler.update()
                 hypernetwork.step += 1
@@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                 _loss_step = 0
 
                 steps_done = hypernetwork.step + 1
-                
+
                 epoch_num = hypernetwork.step // steps_per_epoch
                 epoch_step = hypernetwork.step % steps_per_epoch
 

+ 2 - 4
modules/hypernetworks/ui.py

@@ -1,19 +1,17 @@
 import html
-import os
-import re
 
 import gradio as gr
 import modules.hypernetworks.hypernetwork
 from modules import devices, sd_hijack, shared
 
 not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
 
 
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
 
-    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
 
 
 def train_hypernetwork(*args):

+ 88 - 51
modules/images.py

@@ -13,17 +13,24 @@ import numpy as np
 import piexif
 import piexif.helper
 from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from fonts.ttf import Roboto
 import string
 import json
 import hashlib
 
 from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.paths_internal import roboto_ttf_file
+from modules.shared import opts
 
 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
 
 
+def get_font(fontsize: int):
+    try:
+        return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
+    except Exception:
+        return ImageFont.truetype(roboto_ttf_file, fontsize)
+
+
 def image_grid(imgs, batch_size=1, rows=None):
     if rows is None:
         if opts.n_rows > 0:
@@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
                 lines.append(word)
         return lines
 
-    def get_font(fontsize):
-        try:
-            return ImageFont.truetype(opts.font or Roboto, fontsize)
-        except Exception:
-            return ImageFont.truetype(Roboto, fontsize)
-
     def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
-        for i, line in enumerate(lines):
+        for line in lines:
             fnt = initial_fnt
             fontsize = initial_fontsize
             while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@@ -318,6 +319,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
 re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
 re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
 max_filename_part_length = 128
+NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
 
 
 def sanitize_filename_part(text, replace_spaces=True):
@@ -352,6 +354,11 @@ class FilenameGenerator:
         'prompt_no_styles': lambda self: self.prompt_no_style(),
         'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
         'prompt_words': lambda self: self.prompt_words(),
+        'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
+        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+        'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats: [hasprompt<prompt1|default><prompt2>...]

+        'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
+        'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
     }
     default_time_format = '%Y%m%d%H%M%S'
 
@@ -361,6 +368,22 @@ class FilenameGenerator:
         self.prompt = prompt
         self.image = image
 
+    def hasprompt(self, *args):
+        if self.p is None or self.prompt is None:
+            return None
+        lower = self.prompt.lower()
+        outres = ""
+        for arg in args:
+            if arg != "":
+                division = arg.split("|")
+                expected = division[0].lower()
+                default = division[1] if len(division) > 1 else ""
+                if lower.find(expected) >= 0:
+                    outres = f'{outres}{expected}'
+                else:
+                    outres = outres if default == "" else f'{outres}{default}'
+        return sanitize_filename_part(outres)
+
     def prompt_no_style(self):
         if self.p is None or self.prompt is None:
             return None
@@ -387,13 +410,13 @@ class FilenameGenerator:
         time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
         try:
             time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
-        except pytz.exceptions.UnknownTimeZoneError as _:
+        except pytz.exceptions.UnknownTimeZoneError:
             time_zone = None
 
         time_zone_time = time_datetime.astimezone(time_zone)
         try:
             formatted_time = time_zone_time.strftime(time_format)
-        except (ValueError, TypeError) as _:
+        except (ValueError, TypeError):
             formatted_time = time_zone_time.strftime(self.default_time_format)
 
         return sanitize_filename_part(formatted_time, replace_spaces=False)
@@ -403,9 +426,9 @@ class FilenameGenerator:
 
         for m in re_pattern.finditer(x):
             text, pattern = m.groups()
-            res += text
 
             if pattern is None:
+                res += text
                 continue
 
             pattern_args = []
@@ -426,11 +449,13 @@ class FilenameGenerator:
                     print(f"Error adding [{pattern}] to filename", file=sys.stderr)
                     print(traceback.format_exc(), file=sys.stderr)
 
-                if replacement is not None:
-                    res += str(replacement)
+                if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
+                    continue
+                elif replacement is not None:
+                    res += text + str(replacement)
                     continue
 
-            res += f'[{pattern}]'
+            res += f'{text}[{pattern}]'
 
         return res
 
@@ -443,20 +468,57 @@ def get_next_sequence_number(path, basename):
     """
     result = -1
     if basename != '':
-        basename = basename + "-"
+        basename = f"{basename}-"
 
     prefix_length = len(basename)
     for p in os.listdir(path):
         if p.startswith(basename):
-            l = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+            parts = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
             try:
-                result = max(int(l[0]), result)
+                result = max(int(parts[0]), result)
             except ValueError:
                 pass
 
     return result + 1
 
 
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+    if extension is None:
+        extension = os.path.splitext(filename)[1]
+
+    image_format = Image.registered_extensions()[extension]
+
+    existing_pnginfo = existing_pnginfo or {}
+    if opts.enable_pnginfo:
+        existing_pnginfo['parameters'] = geninfo
+
+    if extension.lower() == '.png':
+        pnginfo_data = PngImagePlugin.PngInfo()
+        for k, v in (existing_pnginfo or {}).items():
+            pnginfo_data.add_text(k, str(v))
+
+        image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+    elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+        if image.mode == 'RGBA':
+            image = image.convert("RGB")
+        elif image.mode == 'I;16':
+            image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
+
+        image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
+
+        if opts.enable_pnginfo and geninfo is not None:
+            exif_bytes = piexif.dump({
+                "Exif": {
+                    piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
+                },
+            })
+
+            piexif.insert(exif_bytes, filename)
+    else:
+        image.save(filename, format=image_format, quality=opts.jpeg_quality)
+
+
 def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
     """Save an image.
 
@@ -512,7 +574,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         add_number = opts.save_images_add_number or file_decoration == ''
 
         if file_decoration != "" and add_number:
-            file_decoration = "-" + file_decoration
+            file_decoration = f"-{file_decoration}"
 
         file_decoration = namegen.apply(file_decoration) + suffix
 
@@ -541,38 +603,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     info = params.pnginfo.get(pnginfo_section_name, None)
 
     def _atomically_save_image(image_to_save, filename_without_extension, extension):
-        # save image with .tmp extension to avoid race condition when another process detects new image in the directory
-        temp_file_path = filename_without_extension + ".tmp"
-        image_format = Image.registered_extensions()[extension]
-
-        if extension.lower() == '.png':
-            pnginfo_data = PngImagePlugin.PngInfo()
-            if opts.enable_pnginfo:
-                for k, v in params.pnginfo.items():
-                    pnginfo_data.add_text(k, str(v))
-
-            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
-
-        elif extension.lower() in (".jpg", ".jpeg", ".webp"):
-            if image_to_save.mode == 'RGBA':
-                image_to_save = image_to_save.convert("RGB")
-            elif image_to_save.mode == 'I;16':
-                image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
-
-            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
-
-            if opts.enable_pnginfo and info is not None:
-                exif_bytes = piexif.dump({
-                    "Exif": {
-                        piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
-                    },
-                })
-
-                piexif.insert(exif_bytes, temp_file_path)
-        else:
-            image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
+        """
+        save image with .tmp extension to avoid race condition when another process detects new image in the directory
+        """
+        temp_file_path = f"{filename_without_extension}.tmp"
+
+        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
 
-        # atomically rename the file with correct extension
         os.replace(temp_file_path, filename_without_extension + extension)
 
     fullfn_without_extension, extension = os.path.splitext(params.filename)
@@ -602,7 +639,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     if opts.save_txt and info is not None:
         txt_fullfn = f"{fullfn_without_extension}.txt"
         with open(txt_fullfn, "w", encoding="utf8") as file:
-            file.write(info + "\n")
+            file.write(f"{info}\n")
     else:
         txt_fullfn = None
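
save_image_with_geninfo is the body of _atomically_save_image factored out above, apparently so other callers can embed generation parameters the same way. A hedged usage sketch (the image, infotext and paths are made up, and modules.images assumes a running webui environment for opts):

# Hedged usage sketch for the new save_image_with_geninfo() helper.
import os
from PIL import Image
from modules.images import save_image_with_geninfo

os.makedirs("outputs", exist_ok=True)
img = Image.new("RGB", (512, 512), "black")
geninfo = "a test prompt\nSteps: 20, Sampler: Euler a, Seed: 42"

save_image_with_geninfo(img, geninfo, "outputs/example.png")           # parameters go into PNG text chunks
save_image_with_geninfo(img, geninfo, "outputs/example.jpg", ".jpg")   # parameters go into the EXIF UserComment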
 

+ 16 - 10
modules/img2img.py

@@ -1,19 +1,15 @@
-import math
 import os
-import sys
-import traceback
 
 import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
 
-from modules import devices, sd_samplers
+from modules import sd_samplers
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
-import modules.images as images
 import modules.scripts
 
 
@@ -46,7 +42,11 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
         if state.interrupted:
             break
 
-        img = Image.open(image)
+        try:
+            img = Image.open(image)
+        except UnidentifiedImageError as e:
+            print(e)
+            continue
         # Use the EXIF orientation of photos taken by smartphones.
         img = ImageOps.exif_transpose(img)
         p.init_images = [img] * p.batch_size
@@ -55,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
             # try to find corresponding mask for an image using simple filename matching
             mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
             # if not found use first one ("same mask for all images" use-case)
-            if not mask_image_path in inpaint_masks:
+            if mask_image_path not in inpaint_masks:
                 mask_image_path = inpaint_masks[0]
             mask_image = Image.open(mask_image_path)
             p.image_mask = mask_image
@@ -78,7 +78,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
                 processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
     override_settings = create_override_settings_dict(override_settings_texts)
 
     is_batch = mode == 5
@@ -114,6 +114,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
     if image is not None:
         image = ImageOps.exif_transpose(image)
 
+    if selected_scale_tab == 1:
+        assert image, "Can't scale by because no image is selected"
+
+        width = int(image.width * scale_by)
+        height = int(image.height * scale_by)
+
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
 
     p = StableDiffusionProcessingImg2Img(
@@ -151,7 +157,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         override_settings=override_settings,
     )
 
-    p.scripts = modules.scripts.scripts_txt2img
+    p.scripts = modules.scripts.scripts_img2img
     p.script_args = args
 
     if shared.cmd_opts.enable_console_prompts:
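
The new selected_scale_tab branch simply derives the target size from the source image and scale_by before the usual width/height path; a tiny worked example with made-up numbers:

# Worked example of the "scale by" branch above (numbers are illustrative).
image_width, image_height, scale_by = 640, 448, 1.5
width = int(image_width * scale_by)    # 960
height = int(image_height * scale_by)  # 672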

+ 7 - 8
modules/interrogate.py

@@ -11,7 +11,6 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
-import modules.shared as shared
 from modules import devices, paths, shared, lowvram, modelloader, errors
 
 blip_image_eval_size = 384
@@ -28,11 +27,11 @@ def category_types():
 def download_default_clip_interrogate_categories(content_dir):
     print("Downloading CLIP categories...")
 
-    tmpdir = content_dir + "_tmp"
+    tmpdir = f"{content_dir}_tmp"
     category_types = ["artists", "flavors", "mediums", "movements"]
 
     try:
-        os.makedirs(tmpdir)
+        os.makedirs(tmpdir, exist_ok=True)
         for category_type in category_types:
             torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
         os.rename(tmpdir, content_dir)
@@ -41,7 +40,7 @@ def download_default_clip_interrogate_categories(content_dir):
         errors.display(e, "downloading default CLIP interrogate categories")
     finally:
         if os.path.exists(tmpdir):
-            os.remove(tmpdir)
+            os.removedirs(tmpdir)
 
 
 class InterrogateModels:
@@ -160,7 +159,7 @@ class InterrogateModels:
             text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
 
         top_count = min(top_count, len(text_array))
-        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+        text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
         text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
         text_features /= text_features.norm(dim=-1, keepdim=True)
 
@@ -208,13 +207,13 @@ class InterrogateModels:
 
                 image_features /= image_features.norm(dim=-1, keepdim=True)
 
-                for name, topn, items in self.categories():
-                    matches = self.rank(image_features, items, top_count=topn)
+                for cat in self.categories():
+                    matches = self.rank(image_features, cat.items, top_count=cat.topn)
                     for match, score in matches:
                         if shared.opts.interrogate_return_ranks:
                             res += f", ({match}:{score/100:.3f})"
                         else:
-                            res += ", " + match
+                            res += f", {match}"
 
         except Exception:
             print("Error interrogating", file=sys.stderr)

+ 2 - 2
modules/localization.py

@@ -23,7 +23,7 @@ def list_localizations(dirname):
         localizations[fn] = file.path
 
 
-def localization_js(current_localization_name):
+def localization_js(current_localization_name: str) -> str:
     fn = localizations.get(current_localization_name, None)
     data = {}
     if fn is not None:
@@ -34,4 +34,4 @@ def localization_js(current_localization_name):
             print(f"Error loading localization from {fn}:", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
 
-    return f"var localization = {json.dumps(data)}\n"
+    return f"window.localization = {json.dumps(data)}"

+ 8 - 4
modules/mac_specific.py

@@ -1,6 +1,5 @@
 import torch
 import platform
-from modules import paths
 from modules.sd_hijack_utils import CondFunc
 from packaging import version
 
@@ -43,7 +42,7 @@ if has_mps:
         # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
         CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
                                                           lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
         CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
                                                                                         lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
         # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@@ -54,6 +53,11 @@ if has_mps:
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
         CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-    if version.parse(torch.__version__) == version.parse("2.0"):
+
         # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
-        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
+        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+
+        # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
+        if platform.processor() == 'i386':
+            for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
+                CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
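
All of these MPS workarounds go through CondFunc, which, judging purely from the call sites above, takes a dotted function path, a replacement that receives the original function followed by the call arguments, and a predicate deciding when the replacement applies. A rough, self-contained sketch of that kind of conditional monkeypatch helper; it is not the real sd_hijack_utils implementation and, unlike it, only resolves module-level functions (so it would not handle 'torch.Tensor.to'):

    import importlib

    def cond_func(func_path, replacement, condition=None):
        """Wrap the function at func_path so that replacement(orig, *args, **kwargs)
        runs whenever condition(orig, *args, **kwargs) is true (or condition is None)."""
        module_path, name = func_path.rsplit(".", 1)
        module = importlib.import_module(module_path)   # simplification: parent must be an importable module
        orig = getattr(module, name)

        def wrapper(*args, **kwargs):
            if condition is None or condition(orig, *args, **kwargs):
                return replacement(orig, *args, **kwargs)
            return orig(*args, **kwargs)

        setattr(module, name, wrapper)

    # example use of this sketch: compute on CPU only for tensors living on the 'mps' device
    # cond_func('torch.cumsum',
    #           lambda orig, input, *a, **kw: orig(input.cpu(), *a, **kw).to(input.device),
    #           lambda _, input, *a, **kw: input.device.type == 'mps')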

+ 1 - 1
modules/masking.py

@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
 def get_crop_region(mask, pad=0):
     """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
     For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-    
+
     h, w = mask.shape
 
     crop_left = 0

+ 21 - 43
modules/modelloader.py

@@ -1,4 +1,3 @@
-import glob
 import os
 import shutil
 import importlib
@@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
     """
     output = []
 
-    if ext_filter is None:
-        ext_filter = []
-
     try:
         places = []
 
@@ -39,22 +35,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
         places.append(model_path)
 
         for place in places:
-            if os.path.exists(place):
-                for file in glob.iglob(place + '**/**', recursive=True):
-                    full_path = file
-                    if os.path.isdir(full_path):
-                        continue
-                    if os.path.islink(full_path) and not os.path.exists(full_path):
-                        print(f"Skipping broken symlink: {full_path}")
-                        continue
-                    if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
-                        continue
-                    if len(ext_filter) != 0:
-                        model_name, extension = os.path.splitext(file)
-                        if extension not in ext_filter:
-                            continue
-                    if file not in output:
-                        output.append(full_path)
+            for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
+                if os.path.islink(full_path) and not os.path.exists(full_path):
+                    print(f"Skipping broken symlink: {full_path}")
+                    continue
+                if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
+                    continue
+                if full_path not in output:
+                    output.append(full_path)
 
         if model_url is not None and len(output) == 0:
             if download_name is not None:
@@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
                     print(f"Moving {file} from {src_path} to {dest_path}.")
                     try:
                         shutil.move(fullpath, dest_path)
-                    except:
+                    except Exception:
                         pass
             if len(os.listdir(src_path)) == 0:
                 print(f"Removing empty folder: {src_path}")
                 shutil.rmtree(src_path, True)
-    except:
+    except Exception:
         pass
 
 
-builtin_upscaler_classes = []
-forbidden_upscaler_classes = set()
-
-
-def list_builtin_upscalers():
-    load_upscalers()
-
-    builtin_upscaler_classes.clear()
-    builtin_upscaler_classes.extend(Upscaler.__subclasses__())
-
-
-def forbid_loaded_nonbuiltin_upscalers():
-    for cls in Upscaler.__subclasses__():
-        if cls not in builtin_upscaler_classes:
-            forbidden_upscaler_classes.add(cls)
-
-
 def load_upscalers():
     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
     # so we'll try to import any _model.py files before looking in __subclasses__
@@ -155,15 +126,22 @@ def load_upscalers():
             full_model = f"modules.{model_name}_model"
             try:
                 importlib.import_module(full_model)
-            except:
+            except Exception:
                 pass
 
     datas = []
     commandline_options = vars(shared.cmd_opts)
-    for cls in Upscaler.__subclasses__():
-        if cls in forbidden_upscaler_classes:
-            continue
 
+    # some upscaler classes will not go away after reloading their modules, and we'll end
+    # up with two copies of those classes. The newest copy will always be the last in the list,
+    # so we go from end to beginning and ignore duplicates
+    used_classes = {}
+    for cls in reversed(Upscaler.__subclasses__()):
+        classname = str(cls)
+        if classname not in used_classes:
+            used_classes[classname] = cls
+
+    for cls in reversed(used_classes.values()):
         name = cls.__name__
         cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
         scaler = cls(commandline_options.get(cmd_name, None))
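
The comment added above explains the reason for the dedup loop: re-importing a *_model.py module re-executes its class definitions, so Upscaler.__subclasses__() can hold a stale copy and a fresh copy under the same qualified name, and walking the list backwards keeps only the newest one. A small self-contained demonstration of the same keep-the-last-duplicate trick on stand-in objects (the class names are illustrative):

    class Entry:
        def __init__(self, label):
            self.label = label
        def __str__(self):
            return self.label   # stand-in for str(cls), which is what the real code keys on

    # pretend this is Upscaler.__subclasses__() after ESRGAN's module was reloaded:
    subclasses = [Entry("ESRGAN"), Entry("SwinIR"), Entry("ESRGAN")]   # the last ESRGAN is the fresh copy

    used_classes = {}
    for cls in reversed(subclasses):            # newest copies are last, so walk backwards
        used_classes.setdefault(str(cls), cls)  # first hit per name wins, i.e. the newest copy

    survivors = list(reversed(used_classes.values()))
    print([c.label for c in survivors])         # ['SwinIR', 'ESRGAN'], one entry per class name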

+ 26 - 30
modules/models/diffusion/ddpm_edit.py

@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
                  beta_schedule="linear",
                  loss_type="l2",
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  load_only_unet=False,
                  monitor="val/loss",
                  use_ema=True,
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
 
             # If initializing from EMA-only checkpoint, create EMA model after loading.
             if self.use_ema and not load_ema:
@@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
                 if context is not None:
                     print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
+        ignore_keys = ignore_keys or []
+
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
             sd = sd["state_dict"]
@@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
         for k in keys:
             for ik in ignore_keys:
                 if k.startswith(ik):
-                    print("Deleting key {} from state_dict.".format(k))
+                    print(f"Deleting key {k} from state_dict.")
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
@@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
         _, loss_dict_no_ema = self.shared_step(batch)
         with self.ema_scope():
             _, loss_dict_ema = self.shared_step(batch)
-            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+            loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
         self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
         self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
 
@@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
 
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.first_stage_key)
         N = min(x.shape[0], N)
         n_row = min(x.shape[0], n_row)
@@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
         log["inputs"] = x
 
         # get diffusion row
-        diffusion_row = list()
+        diffusion_row = []
         x_start = x[:n_row]
 
         for t in range(self.num_timesteps):
@@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
             conditioning_key = None
         ckpt_path = kwargs.pop("ckpt_path", None)
         ignore_keys = kwargs.pop("ignore_keys", [])
-        super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
+        super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
                 c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)
 
-    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
-        def rescale_bbox(bbox):
-            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
-            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
-            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
-            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
-            return x0, y0, w, h
-
-        return [rescale_bbox(b) for b in bboxes]
-
     def apply_model(self, x_noisy, t, cond, return_ids=False):
 
         if isinstance(cond, dict):
@@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
 
@@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(x0_partial)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         return img, intermediates
 
     @torch.no_grad()
@@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(img)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
 
         if return_intermediates:
             return img, intermediates
@@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
         return self.p_sample_loop(cond,
@@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
 
         use_ddim = False
 
-        log = dict()
+        log = {}
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
                                            force_c_encode=True,
@@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
 
         if plot_diffusion_rows:
             # get diffusion row
-            diffusion_row = list()
+            diffusion_row = []
             z_start = z[:n_row]
             for t in range(self.num_timesteps):
                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
 
             if inpaint:
                 # make a simple center square
-                b, h, w = z.shape[0], z.shape[2], z.shape[3]
+                h, w = z.shape[2], z.shape[3]
                 mask = torch.ones(N, h, w).to(self.device)
                 # zeros will be filled in
                 mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
         assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
-        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
 
     def log_images(self, batch, N=8, *args, **kwargs):
-        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+        logs = super().log_images(*args, batch=batch, N=N, **kwargs)
 
         key = 'train' if self.training else 'validation'
         dset = self.trainer.datamodule.datasets[key]
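
Two recurring fixes in this file are worth spelling out. The `ignore_keys=[]` to `ignore_keys=None` change (paired with `ignore_keys or []`) addresses Python's mutable-default-argument pitfall: a default list is created once, at definition time, and then shared by every call. The `super().__init__(*args, conditioning_key=..., **kwargs)` reordering keeps keyword arguments after the *args unpacking, the ordering linters expect and the one that reads unambiguously. A minimal demonstration of the first pitfall and its fix:

    def buggy(keys=[]):       # one shared list for every call
        keys.append("x")
        return keys

    def fixed(keys=None):     # the sentinel forces a fresh list per call
        keys = keys or []
        keys.append("x")
        return keys

    print(buggy(), buggy())   # ['x', 'x'] ['x', 'x'] -- both names point at the same, twice-mutated list
    print(fixed(), fixed())   # ['x'] ['x']           -- independent lists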

+ 1 - 1
modules/models/diffusion/uni_pc/__init__.py

@@ -1 +1 @@
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler  # noqa: F401

+ 2 - 1
modules/models/diffusion/uni_pc/sampler.py

@@ -54,7 +54,8 @@ class UniPCSampler(object):
         if conditioning is not None:
             if isinstance(conditioning, dict):
                 ctmp = conditioning[list(conditioning.keys())[0]]
-                while isinstance(ctmp, list): ctmp = ctmp[0]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
                 cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

+ 46 - 40
modules/models/diffusion/uni_pc/uni_pc.py

@@ -1,7 +1,6 @@
 import torch
-import torch.nn.functional as F
 import math
-from tqdm.auto import trange
+import tqdm
 
 
 class NoiseScheduleVP:
@@ -94,7 +93,7 @@ class NoiseScheduleVP:
         """
 
         if schedule not in ['discrete', 'linear', 'cosine']:
-            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+            raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
 
         self.schedule = schedule
         if schedule == 'discrete':
@@ -179,13 +178,13 @@ def model_wrapper(
     model,
     noise_schedule,
     model_type="noise",
-    model_kwargs={},
+    model_kwargs=None,
     guidance_type="uncond",
     #condition=None,
     #unconditional_condition=None,
     guidance_scale=1.,
     classifier_fn=None,
-    classifier_kwargs={},
+    classifier_kwargs=None,
 ):
     """Create a wrapper function for the noise prediction model.
 
@@ -276,6 +275,9 @@ def model_wrapper(
         A noise prediction model that accepts the noised data and the continuous time as the inputs.
     """
 
+    model_kwargs = model_kwargs or {}
+    classifier_kwargs = classifier_kwargs or {}
+
     def get_model_input_time(t_continuous):
         """
         Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
@@ -342,7 +344,7 @@ def model_wrapper(
                 t_in = torch.cat([t_continuous] * 2)
                 if isinstance(condition, dict):
                     assert isinstance(unconditional_condition, dict)
-                    c_in = dict()
+                    c_in = {}
                     for k in condition:
                         if isinstance(condition[k], list):
                             c_in[k] = [torch.cat([
@@ -353,7 +355,7 @@ def model_wrapper(
                                 unconditional_condition[k],
                                 condition[k]])
                 elif isinstance(condition, list):
-                    c_in = list()
+                    c_in = []
                     assert isinstance(unconditional_condition, list)
                     for i in range(len(condition)):
                         c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
@@ -469,7 +471,7 @@ class UniPC:
             t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
             return t
         else:
-            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+            raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
 
     def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
         """
@@ -757,40 +759,44 @@ class UniPC:
                 vec_t = timesteps[0].expand((x.shape[0]))
                 model_prev_list = [self.model_fn(x, vec_t)]
                 t_prev_list = [vec_t]
-                # Init the first `order` values by lower order multistep DPM-Solver.
-                for init_order in range(1, order):
-                    vec_t = timesteps[init_order].expand(x.shape[0])
-                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
-                    if model_x is None:
-                        model_x = self.model_fn(x, vec_t)
-                    if self.after_update is not None:
-                        self.after_update(x, model_x)
-                    model_prev_list.append(model_x)
-                    t_prev_list.append(vec_t)
-                for step in trange(order, steps + 1):
-                    vec_t = timesteps[step].expand(x.shape[0])
-                    if lower_order_final:
-                        step_order = min(order, steps + 1 - step)
-                    else:
-                        step_order = order
-                    #print('this step order:', step_order)
-                    if step == steps:
-                        #print('do not run corrector at the last step')
-                        use_corrector = False
-                    else:
-                        use_corrector = True
-                    x, model_x =  self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
-                    if self.after_update is not None:
-                        self.after_update(x, model_x)
-                    for i in range(order - 1):
-                        t_prev_list[i] = t_prev_list[i + 1]
-                        model_prev_list[i] = model_prev_list[i + 1]
-                    t_prev_list[-1] = vec_t
-                    # We do not need to evaluate the final model value.
-                    if step < steps:
+                with tqdm.tqdm(total=steps) as pbar:
+                    # Init the first `order` values by lower order multistep DPM-Solver.
+                    for init_order in range(1, order):
+                        vec_t = timesteps[init_order].expand(x.shape[0])
+                        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                         if model_x is None:
                             model_x = self.model_fn(x, vec_t)
-                        model_prev_list[-1] = model_x
+                        if self.after_update is not None:
+                            self.after_update(x, model_x)
+                        model_prev_list.append(model_x)
+                        t_prev_list.append(vec_t)
+                        pbar.update()
+
+                    for step in range(order, steps + 1):
+                        vec_t = timesteps[step].expand(x.shape[0])
+                        if lower_order_final:
+                            step_order = min(order, steps + 1 - step)
+                        else:
+                            step_order = order
+                        #print('this step order:', step_order)
+                        if step == steps:
+                            #print('do not run corrector at the last step')
+                            use_corrector = False
+                        else:
+                            use_corrector = True
+                        x, model_x =  self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+                        if self.after_update is not None:
+                            self.after_update(x, model_x)
+                        for i in range(order - 1):
+                            t_prev_list[i] = t_prev_list[i + 1]
+                            model_prev_list[i] = model_prev_list[i + 1]
+                        t_prev_list[-1] = vec_t
+                        # We do not need to evaluate the final model value.
+                        if step < steps:
+                            if model_x is None:
+                                model_x = self.model_fn(x, vec_t)
+                            model_prev_list[-1] = model_x
+                        pbar.update()
         else:
             raise NotImplementedError()
         if denoise_to_zero:
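
The restructured sampling loop above replaces `trange` over only the main steps with a single manually driven progress bar that also counts the warm-up (lower-order initialisation) iterations, so the bar reaches exactly `steps` updates. A stripped-down sketch of that tqdm pattern, with the solver calls replaced by sleeps:

    import time
    import tqdm

    order, steps = 3, 10

    with tqdm.tqdm(total=steps) as pbar:
        for _ in range(1, order):            # warm-up: order - 1 iterations
            time.sleep(0.01)                 # stand-in for multistep_uni_pc_update(...)
            pbar.update()
        for _ in range(order, steps + 1):    # main loop: steps + 1 - order iterations
            time.sleep(0.01)                 # stand-in for the per-step solver update
            pbar.update()
    # (order - 1) + (steps + 1 - order) == steps, so the bar finishes at 100%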

+ 14 - 2
modules/ngrok.py

@@ -7,12 +7,24 @@ def connect(token, port, region):
     else:
         if ':' in token:
             # token = authtoken:username:password
-            account = token.split(':')[1] + ':' + token.split(':')[-1]
-            token = token.split(':')[0]
+            token, username, password = token.split(':', 2)
+            account = f"{username}:{password}"
 
     config = conf.PyngrokConfig(
         auth_token=token, region=region
     )
+
+    # Guard for existing tunnels
+    existing = ngrok.get_tunnels(pyngrok_config=config)
+    if existing:
+        for established in existing:
+            # Extra configuration in the case that the user is also using ngrok for other tunnels
+            if established.config['addr'][-4:] == str(port):
+                public_url = existing[0].public_url
+                print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
+                    'You can use this link after the launch is complete.')
+                return
+
     try:
         if account is None:
             public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url

+ 3 - 3
modules/paths.py

@@ -1,8 +1,8 @@
 import os
 import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir  # noqa: F401
 
-import modules.safe
+import modules.safe  # noqa: F401
 
 
 # data_path = cmd_opts_pre.data
@@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
         sd_path = os.path.abspath(possible_sd_path)
         break
 
-assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
+assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
 
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),

+ 11 - 2
modules/paths_internal.py

@@ -2,8 +2,14 @@
 
 import argparse
 import os
+import sys
+import shlex
 
-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
+sys.argv += shlex.split(commandline_args)
+
+modules_path = os.path.dirname(os.path.realpath(__file__))
+script_path = os.path.dirname(modules_path)
 
 sd_configs_path = os.path.join(script_path, "configs")
 sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
@@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
 
 # Parse the --data-dir flag first so we can use it as a base for our other argument default values
 parser_pre = argparse.ArgumentParser(add_help=False)
-parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
 cmd_opts_pre = parser_pre.parse_known_args()[0]
 
 data_path = cmd_opts_pre.data_dir
@@ -20,3 +26,6 @@ data_path = cmd_opts_pre.data_dir
 models_path = os.path.join(data_path, "models")
 extensions_dir = os.path.join(data_path, "extensions")
 extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
+config_states_dir = os.path.join(script_path, "config_states")
+
+roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
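
Splicing `shlex.split(COMMANDLINE_ARGS)` into sys.argv this early means the `--data-dir` pre-parser, and every argparse call after it, sees environment-supplied flags exactly as if they had been typed on the command line, quoted values included. A quick illustration of the shlex behaviour being relied on:

    import shlex

    # quotes and escapes are split the way a POSIX shell would split them
    print(shlex.split('--data-dir "D:/my data" --xformers'))
    # ['--data-dir', 'D:/my data', '--xformers']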

+ 7 - 2
modules/postprocessing.py

@@ -18,9 +18,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
 
     if extras_mode == 1:
         for img in image_folder:
-            image = Image.open(img)
+            if isinstance(img, Image.Image):
+                image = img
+                fn = ''
+            else:
+                image = Image.open(os.path.abspath(img.name))
+                fn = os.path.splitext(img.orig_name)[0]
             image_data.append(image)
-            image_names.append(os.path.splitext(img.orig_name)[0])
+            image_names.append(fn)
     elif extras_mode == 2:
         assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
         assert input_dir, 'input directory not selected'

+ 94 - 16
modules/processing.py

@@ -2,7 +2,7 @@ import json
 import math
 import os
 import sys
-import warnings
+import hashlib
 
 import torch
 import numpy as np
@@ -10,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
 import random
 import cv2
 from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 
+
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
 opt_f = 8
@@ -105,7 +106,7 @@ class StableDiffusionProcessing:
     """
     The first set of parameters: sd_models -> do_not_reload_embeddings represents the minimum required to create a StableDiffusionProcessing
     """
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
         if sampler_index is not None:
             print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
 
@@ -140,6 +141,7 @@ class StableDiffusionProcessing:
         self.denoising_strength: float = denoising_strength
         self.sampler_noise_scheduler_override = None
         self.ddim_discretize = ddim_discretize or opts.ddim_discretize
+        self.s_min_uncond = s_min_uncond or opts.s_min_uncond
         self.s_churn = s_churn or opts.s_churn
         self.s_tmin = s_tmin or opts.s_tmin
         self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
@@ -148,6 +150,8 @@ class StableDiffusionProcessing:
         self.override_settings_restore_afterwards = override_settings_restore_afterwards
         self.is_using_inpainting_conditioning = False
         self.disable_extra_networks = False
+        self.token_merging_ratio = 0
+        self.token_merging_ratio_hr = 0
 
         if not seed_enable_extras:
             self.subseed = -1
@@ -162,6 +166,9 @@ class StableDiffusionProcessing:
         self.all_seeds = None
         self.all_subseeds = None
         self.iteration = 0
+        self.is_hr_pass = False
+        self.sampler = None
+
 
     @property
     def sd_model(self):
@@ -270,6 +277,12 @@ class StableDiffusionProcessing:
     def close(self):
         self.sampler = None
 
+    def get_token_merging_ratio(self, for_hr=False):
+        if for_hr:
+            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+        return self.token_merging_ratio or opts.token_merging_ratio
+
 
 class Processed:
     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -299,6 +312,8 @@ class Processed:
         self.styles = p.styles
         self.job_timestamp = state.job_timestamp
         self.clip_skip = opts.CLIP_stop_at_last_layers
+        self.token_merging_ratio = p.token_merging_ratio
+        self.token_merging_ratio_hr = p.token_merging_ratio_hr
 
         self.eta = p.eta
         self.ddim_discretize = p.ddim_discretize
@@ -306,6 +321,7 @@ class Processed:
         self.s_tmin = p.s_tmin
         self.s_tmax = p.s_tmax
         self.s_noise = p.s_noise
+        self.s_min_uncond = p.s_min_uncond
         self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
         self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
         self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -356,6 +372,9 @@ class Processed:
     def infotext(self, p: StableDiffusionProcessing, index):
         return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
 
+    def get_token_merging_ratio(self, for_hr=False):
+        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
 
 # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
 def slerp(val, low, high):
@@ -454,10 +473,27 @@ def fix_seed(p):
     p.subseed = get_fixed_seed(p.subseed)
 
 
+def program_version():
+    import launch
+
+    res = launch.git_tag()
+    if res == "<none>":
+        res = None
+
+    return res
+
+
 def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
     index = position_in_batch + iteration * p.batch_size
 
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+    enable_hr = getattr(p, 'enable_hr', False)
+    token_merging_ratio = p.get_token_merging_ratio()
+    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
+
+    uses_ensd = opts.eta_noise_seed_delta != 0
+    if uses_ensd:
+        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
 
     generation_params = {
         "Steps": p.steps,
@@ -475,14 +511,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Denoising strength": getattr(p, 'denoising_strength', None),
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
-        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
+        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
+        "Init image hash": getattr(p, 'init_img_hash', None),
+        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+        **p.extra_generation_params,
+        "Version": program_version() if opts.add_version_to_infotext else None,
     }
 
-    generation_params.update(p.extra_generation_params)
-
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
-    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
+    negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
 
     return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
 
@@ -491,6 +532,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
 
     try:
+        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
+        if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+            p.override_settings.pop('sd_model_checkpoint', None)
+            sd_models.reload_model_weights()
+
         for k, v in p.override_settings.items():
             setattr(opts, k, v)
 
@@ -500,15 +546,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if k == 'sd_vae':
                 sd_vae.reload_vae_weights()
 
+        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+
         res = process_images_inner(p)
 
     finally:
+        sd_models.apply_token_merging(p.sd_model, 0)
+
         # restore opts to original state
         if p.override_settings_restore_afterwards:
             for k, v in stored_opts.items():
                 setattr(opts, k, v)
-                if k == 'sd_model_checkpoint':
-                    sd_models.reload_model_weights()
 
                 if k == 'sd_vae':
                     sd_vae.reload_vae_weights()
@@ -639,8 +687,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     processed = Processed(p, [], p.seed, "")
                     file.write(processed.infotext(p, 0))
 
-            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
-            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
+            sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+            step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
+            uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
+            c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
 
             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
@@ -670,6 +720,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
 
             for i, x_sample in enumerate(x_samples_ddim):
+                p.batch_index = i
+
                 x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                 x_sample = x_sample.astype(np.uint8)
 
@@ -706,9 +758,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     image.info["parameters"] = text
                 output_images.append(image)
 
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                     image_mask = p.mask_for_overlay.convert('RGB')
-                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
+                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
 
                     if opts.save_mask:
                         images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
@@ -718,7 +770,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                     if opts.return_mask:
                         output_images.append(image_mask)
-                    
+
                     if opts.return_mask_composite:
                         output_images.append(image_mask_composite)
 
@@ -751,7 +803,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
     devices.torch_gc()
 
-    res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+    res = Processed(
+        p,
+        images_list=output_images,
+        seed=p.all_seeds[0],
+        info=infotext(),
+        comments="".join(f"\n\n{comment}" for comment in comments),
+        subseed=p.all_subseeds[0],
+        index_of_first_image=index_of_first_image,
+        infotexts=infotexts,
+    )
 
     if p.scripts is not None:
         p.scripts.postprocess(p, res)
@@ -871,6 +932,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if not self.enable_hr:
             return samples
 
+        self.is_hr_pass = True
+
         target_width = self.hr_upscale_to_x
         target_height = self.hr_upscale_to_y
 
@@ -938,8 +1001,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         x = None
         devices.torch_gc()
 
+        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
         samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
+        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+
+        self.is_hr_pass = False
+
         return samples
 
 
@@ -1007,6 +1076,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             self.color_corrections = []
         imgs = []
         for img in self.init_images:
+
+            # Save init image
+            if opts.save_init_img:
+                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+
             image = images.flatten(img, opts.img2img_background_color)
 
             if crop_region is None and self.resize_mode != 3:
@@ -1093,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         devices.torch_gc()
 
         return samples
+
+    def get_token_merging_ratio(self, for_hr=False):
+        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
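
create_infotext builds a dict in which "unset" is represented as None and filtered out only at join time, with `**p.extra_generation_params` merged in before the Version entry so scripts can contribute their own keys. A reduced sketch of that assemble-and-filter step; quote() here is a simplified stand-in for generation_parameters_copypaste.quote, and the sample values are made up:

    def quote(v):
        # simplified stand-in: the real helper only quotes values that need escaping
        text = str(v)
        return f'"{text}"' if ("," in text or ":" in text) else text

    extra_generation_params = {"Hires upscale": 2.0}   # illustrative script-supplied entry

    generation_params = {
        "Steps": 20,
        "Sampler": "Euler a",
        "Clip skip": None,                 # None means: leave it out of the infotext
        "Token merging ratio": None,
        **extra_generation_params,
        "Version": None,                   # only filled in when add_version_to_infotext is enabled
    }

    infotext = ", ".join(
        k if k == v else f"{k}: {quote(v)}"
        for k, v in generation_params.items()
        if v is not None
    )
    print(infotext)   # Steps: 20, Sampler: Euler a, Hires upscale: 2.0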

+ 32 - 2
modules/progress.py

@@ -13,6 +13,8 @@ import modules.shared as shared
 current_task = None
 pending_tasks = {}
 finished_tasks = []
+recorded_results = []
+recorded_results_limit = 2
 
 
 def start_task(id_task):
@@ -33,6 +35,12 @@ def finish_task(id_task):
         finished_tasks.pop(0)
 
 
+def record_results(id_task, res):
+    recorded_results.append((id_task, res))
+    if len(recorded_results) > recorded_results_limit:
+        recorded_results.pop(0)
+
+
 def add_task_to_queue(id_job):
     pending_tasks[id_job] = time.time()
 
@@ -87,8 +95,20 @@ def progressapi(req: ProgressRequest):
         image = shared.state.current_image
         if image is not None:
             buffered = io.BytesIO()
-            image.save(buffered, format="png")
-            live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+
+            if opts.live_previews_image_format == "png":
+                # using optimize for large images takes an enormous amount of time
+                if max(*image.size) <= 256:
+                    save_kwargs = {"optimize": True}
+                else:
+                    save_kwargs = {"optimize": False, "compress_level": 1}
+
+            else:
+                save_kwargs = {}
+
+            image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
+            base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+            live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
             id_live_preview = shared.state.id_live_preview
         else:
             live_preview = None
@@ -97,3 +117,13 @@ def progressapi(req: ProgressRequest):
 
     return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
 
+
+def restore_progress(id_task):
+    while id_task == current_task or id_task in pending_tasks:
+        time.sleep(0.1)
+
+    res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
+    if res is not None:
+        return res
+
+    return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
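
The live-preview branch now honours opts.live_previews_image_format and, for PNG, trades compression for speed on large images (optimize=True only up to 256px, otherwise compress_level=1). A self-contained sketch of producing the same kind of base64 data URI from a PIL image; the option values are taken from the diff, the image here is synthetic:

    import base64
    import io

    from PIL import Image

    def encode_live_preview(image: Image.Image, image_format: str = "png") -> str:
        buffered = io.BytesIO()
        if image_format == "png":
            # optimize is slow on big images, so fall back to a fast compression level
            save_kwargs = {"optimize": True} if max(*image.size) <= 256 else {"optimize": False, "compress_level": 1}
        else:
            save_kwargs = {}
        image.save(buffered, format=image_format, **save_kwargs)
        b64 = base64.b64encode(buffered.getvalue()).decode("ascii")
        return f"data:image/{image_format};base64,{b64}"

    print(encode_live_preview(Image.new("RGB", (64, 64)))[:40])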

+ 15 - 12
modules/prompt_parser.py

@@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     """
 
     def collect_steps(steps, tree):
-        l = [steps]
+        res = [steps]
+
         class CollectSteps(lark.Visitor):
             def scheduled(self, tree):
                 tree.children[-1] = float(tree.children[-1])
                 if tree.children[-1] < 1:
                     tree.children[-1] *= steps
                 tree.children[-1] = min(steps, int(tree.children[-1]))
-                l.append(tree.children[-1])
+                res.append(tree.children[-1])
+
             def alternate(self, tree):
-                l.extend(range(1, steps+1))
+                res.extend(range(1, steps+1))
+
         CollectSteps().visit(tree)
-        return sorted(set(l))
+        return sorted(set(res))
 
     def at_step(step, tree):
         class AtStep(lark.Transformer):
@@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     def get_schedule(prompt):
         try:
             tree = schedule_parser.parse(prompt)
-        except lark.exceptions.LarkError as e:
+        except lark.exceptions.LarkError:
             if 0:
                 import traceback
                 traceback.print_exc()
@@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
         conds = model.get_learned_conditioning(texts)
 
         cond_schedule = []
-        for i, (end_at_step, text) in enumerate(prompt_schedule):
+        for i, (end_at_step, _) in enumerate(prompt_schedule):
             cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
 
         cache[prompt] = cond_schedule
@@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
     res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
     for i, cond_schedule in enumerate(c):
         target_index = 0
-        for current, (end_at, cond) in enumerate(cond_schedule):
-            if current_step <= end_at:
+        for current, entry in enumerate(cond_schedule):
+            if current_step <= entry.end_at_step:
                 target_index = current
                 break
         res[i] = cond_schedule[target_index].cond
@@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     tensors = []
     conds_list = []
 
-    for batch_no, composable_prompts in enumerate(c.batch):
+    for composable_prompts in c.batch:
         conds_for_batch = []
 
-        for cond_index, composable_prompt in enumerate(composable_prompts):
+        for composable_prompt in composable_prompts:
             target_index = 0
-            for current, (end_at, cond) in enumerate(composable_prompt.schedules):
-                if current_step <= end_at:
+            for current, entry in enumerate(composable_prompt.schedules):
+                if current_step <= entry.end_at_step:
                     target_index = current
                     break
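
ScheduledPromptConditioning is constructed as ScheduledPromptConditioning(end_at_step, conds[i]) earlier in this file, and the new loops read entry.end_at_step, so it behaves like a named tuple; the loop simply picks the first schedule entry that is still active at the current step. A toy version of that selection, with the cond payloads reduced to strings:

    from collections import namedtuple

    # field names follow the attribute accesses in the diff
    ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

    schedule = [
        ScheduledPromptConditioning(end_at_step=10, cond="cond for steps 1-10"),
        ScheduledPromptConditioning(end_at_step=20, cond="cond for steps 11-20"),
    ]

    def pick(schedule, current_step):
        target_index = 0
        for current, entry in enumerate(schedule):
            if current_step <= entry.end_at_step:
                target_index = current
                break
        return schedule[target_index].cond

    print(pick(schedule, 5), "|", pick(schedule, 15))
    # cond for steps 1-10 | cond for steps 11-20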
 

+ 17 - 7
modules/realesrgan_model.py

@@ -9,7 +9,7 @@ from realesrgan import RealESRGANer
 
 from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import cmd_opts, opts
-
+from modules import modelloader
 
 class UpscalerRealESRGAN(Upscaler):
     def __init__(self, path):
@@ -17,13 +17,21 @@ class UpscalerRealESRGAN(Upscaler):
         self.user_path = path
         super().__init__()
         try:
-            from basicsr.archs.rrdbnet_arch import RRDBNet
-            from realesrgan import RealESRGANer
-            from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
+            from realesrgan import RealESRGANer  # noqa: F401
+            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
             self.enable = True
             self.scalers = []
             scalers = self.load_models(path)
+
+            local_model_paths = self.find_models(ext_filter=[".pth"])
             for scaler in scalers:
+                if scaler.local_data_path.startswith("http"):
+                    filename = modelloader.friendly_name(scaler.local_data_path)
+                    local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
+                    if local_model_candidates:
+                        scaler.local_data_path = local_model_candidates[0]
+
                 if scaler.name in opts.realesrgan_enabled_models:
                     self.scalers.append(scaler)
 
@@ -39,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
 
         info = self.load_model(path)
         if not os.path.exists(info.local_data_path):
-            print("Unable to load RealESRGAN model: %s" % info.name)
+            print(f"Unable to load RealESRGAN model: {info.name}")
             return img
 
         upsampler = RealESRGANer(
@@ -64,7 +72,9 @@ class UpscalerRealESRGAN(Upscaler):
                 print(f"Unable to find model info: {path}")
                 return None
 
-            info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+            if info.local_data_path.startswith("http"):
+                info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+
             return info
         except Exception as e:
             print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
@@ -124,6 +134,6 @@ def get_realesrgan_models(scaler):
             ),
         ]
         return models
-    except Exception as e:
+    except Exception:
         print("Error making Real-ESRGAN models list:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)

+ 9 - 8
modules/safe.py

@@ -1,6 +1,5 @@
 # this code is adapted from the script contributed by anon from /h/
 
-import io
 import pickle
 import collections
 import sys
@@ -12,11 +11,9 @@ import _codecs
 import zipfile
 import re
 
-
 # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
 TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
 
-
 def encode(*args):
     out = _codecs.encode(*args)
     return out
@@ -27,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler):
 
     def persistent_load(self, saved_id):
         assert saved_id[0] == 'storage'
-        return TypedStorage()
+
+        try:
+            return TypedStorage(_internal=True)
+        except TypeError:
+            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument
 
     def find_class(self, module, name):
         if self.extra_handler is not None:
@@ -39,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler):
             return getattr(collections, name)
         if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
             return getattr(torch._utils, name)
-        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
+        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
             return getattr(torch, name)
         if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
             return getattr(torch.nn.modules.container, name)
@@ -94,16 +95,16 @@ def check_pt(filename, extra_handler):
 
     except zipfile.BadZipfile:
 
-        # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
         with open(filename, "rb") as file:
             unpickler = RestrictedUnpickler(file)
             unpickler.extra_handler = extra_handler
-            for i in range(5):
+            for _ in range(5):
                 unpickler.load()
 
 
 def load(filename, *args, **kwargs):
-    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
 
 
 def load_with_extra(filename, extra_handler=None, *args, **kwargs):
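
The persistent_load change is a version-compatibility shim: newer PyTorch wants `TypedStorage(_internal=True)`, older releases do not accept the argument, so the code tries the new form and falls back on TypeError. The load_with_extra reordering at the bottom is the same keyword-after-*args cleanup seen elsewhere in this commit. The shim pattern in isolation, with a stand-in class instead of torch.storage.TypedStorage:

    class OldStorage:
        """Stand-in for a pre-2.0 TypedStorage: its __init__ has no _internal parameter."""
        def __init__(self):
            pass

    def make_storage(cls):
        try:
            return cls(_internal=True)   # newer signature
        except TypeError:
            return cls()                 # older releases reject the keyword, fall back

    print(type(make_storage(OldStorage)).__name__)   # OldStorage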

+ 53 - 7
modules/script_callbacks.py

@@ -32,27 +32,42 @@ class CFGDenoiserParams:
     def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
         self.x = x
         """Latent image representation in the process of being denoised"""
-        
+
         self.image_cond = image_cond
         """Conditioning image"""
-        
+
         self.sigma = sigma
         """Current sigma noise step value"""
-        
+
         self.sampling_step = sampling_step
         """Current Sampling step number"""
-        
+
         self.total_sampling_steps = total_sampling_steps
         """Total number of sampling steps planned"""
-        
+
         self.text_cond = text_cond
         """ Encoder hidden states of text conditioning from prompt"""
-        
+
         self.text_uncond = text_uncond
         """ Encoder hidden states of text conditioning from negative prompt"""
 
 
 class CFGDenoisedParams:
+    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+        self.x = x
+        """Latent image representation in the process of being denoised"""
+
+        self.sampling_step = sampling_step
+        """Current Sampling step number"""
+
+        self.total_sampling_steps = total_sampling_steps
+        """Total number of sampling steps planned"""
+
+        self.inner_model = inner_model
+        """Inner model reference used for denoising"""
+
+
+class AfterCFGCallbackParams:
     def __init__(self, x, sampling_step, total_sampling_steps):
         self.x = x
         """Latent image representation in the process of being denoised"""
@@ -87,12 +102,14 @@ callback_map = dict(
     callbacks_image_saved=[],
     callbacks_cfg_denoiser=[],
     callbacks_cfg_denoised=[],
+    callbacks_cfg_after_cfg=[],
     callbacks_before_component=[],
     callbacks_after_component=[],
     callbacks_image_grid=[],
     callbacks_infotext_pasted=[],
     callbacks_script_unloaded=[],
     callbacks_before_ui=[],
+    callbacks_on_reload=[],
 )
 
 
@@ -109,6 +126,14 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
             report_exception(c, 'app_started_callback')
 
 
+def app_reload_callback():
+    for c in callback_map['callbacks_on_reload']:
+        try:
+            c.callback()
+        except Exception:
+            report_exception(c, 'callbacks_on_reload')
+
+
 def model_loaded_callback(sd_model):
     for c in callback_map['callbacks_model_loaded']:
         try:
@@ -177,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
             report_exception(c, 'cfg_denoised_callback')
 
 
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+    for c in callback_map['callbacks_cfg_after_cfg']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'cfg_after_cfg_callback')
+
+
 def before_component_callback(component, **kwargs):
     for c in callback_map['callbacks_before_component']:
         try:
@@ -231,7 +264,7 @@ def add_callback(callbacks, fun):
 
     callbacks.append(ScriptCallback(filename, fun))
 
-    
+
 def remove_current_script_callbacks():
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -254,6 +287,11 @@ def on_app_started(callback):
     add_callback(callback_map['callbacks_app_started'], callback)
 
 
+def on_before_reload(callback):
+    """register a function to be called just before the server reloads."""
+    add_callback(callback_map['callbacks_on_reload'], callback)
+
+
 def on_model_loaded(callback):
     """register a function to be called when the stable diffusion model is created; the model is
     passed as an argument; this function is also called when the script is reloaded. """
@@ -318,6 +356,14 @@ def on_cfg_denoised(callback):
     add_callback(callback_map['callbacks_cfg_denoised'], callback)
 
 
+def on_cfg_after_cfg(callback):
+    """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
+    The callback is called with one argument:
+        - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
+    """
+    add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
 def on_before_component(callback):
     """register a function to be called before a component is created.
     The callback is called with arguments:

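For extension authors, the two new hooks added above (`on_before_reload` and `on_cfg_after_cfg`) register exactly like the existing callbacks. A minimal sketch of an extension script using them; the cleanup logic and file path are hypothetical, only the `script_callbacks` API shown in this diff is assumed:

```python
# extensions/my-extension/scripts/my_hooks.py  (hypothetical extension file)
from modules import script_callbacks


def clear_my_cache():
    # Hypothetical cleanup to run just before the server reloads (e.g. "Reload UI").
    print("dropping cached tensors before reload")


def after_cfg(params: script_callbacks.AfterCFGCallbackParams):
    # params.x is the latent after the cfg calculation for this step, intended for
    # post-processing per the docstring above; step counters report sampling progress.
    print(f"after cfg at step {params.sampling_step} of {params.total_sampling_steps}, "
          f"latent shape {tuple(params.x.shape)}")


script_callbacks.on_before_reload(clear_my_cache)
script_callbacks.on_cfg_after_cfg(after_cfg)
```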
+ 0 - 1
modules/script_loading.py

@@ -2,7 +2,6 @@ import os
 import sys
 import traceback
 import importlib.util
-from types import ModuleType
 
 
 def load_module(path):

+ 36 - 8
modules/scripts.py

@@ -17,6 +17,9 @@ class PostprocessImageArgs:
 
 
 class Script:
+    name = None
+    """script's internal name derived from title"""
+
     filename = None
     args_from = None
     args_to = None
@@ -25,8 +28,8 @@ class Script:
     is_txt2img = False
     is_img2img = False
 
-    """A gr.Group component that has all script's UI inside it"""
     group = None
+    """A gr.Group component that has all script's UI inside it"""
 
     infotext_fields = None
     """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -38,6 +41,9 @@ class Script:
     various "Send to <X>" buttons when clicked
     """
 
+    api_info = None
+    """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+
     def title(self):
         """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
 
@@ -163,7 +169,8 @@ class Script:
         """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
 
         need_tabname = self.show(True) == self.show(False)
-        tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+        tabkind = 'img2img' if self.is_img2img else 'txt2txt'
+        tabname = f"{tabkind}_" if need_tabname else ""
         title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
 
         return f'script_{tabname}{title}_{item_id}'
@@ -230,7 +237,7 @@ def load_scripts():
     syspath = sys.path
 
     def register_scripts_from_module(module):
-        for key, script_class in module.__dict__.items():
+        for script_class in module.__dict__.values():
             if type(script_class) != type:
                 continue
 
@@ -294,9 +301,9 @@ class ScriptRunner:
 
         auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
 
-        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
-            script = script_class()
-            script.filename = path
+        for script_data in auto_processing_scripts + scripts_data:
+            script = script_data.script_class()
+            script.filename = script_data.path
             script.is_txt2img = not is_img2img
             script.is_img2img = is_img2img
 
@@ -312,6 +319,8 @@ class ScriptRunner:
                 self.selectable_scripts.append(script)
 
     def setup_ui(self):
+        import modules.api.models as api_models
+
         self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
 
         inputs = [None]
@@ -326,9 +335,28 @@ class ScriptRunner:
             if controls is None:
                 return
 
+            script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+            api_args = []
+
             for control in controls:
                 control.custom_script_source = os.path.basename(script.filename)
 
+                arg_info = api_models.ScriptArg(label=control.label or "")
+
+                for field in ("value", "minimum", "maximum", "step", "choices"):
+                    v = getattr(control, field, None)
+                    if v is not None:
+                        setattr(arg_info, field, v)
+
+                api_args.append(arg_info)
+
+            script.api_info = api_models.ScriptInfo(
+                name=script.name,
+                is_img2img=script.is_img2img,
+                is_alwayson=script.alwayson,
+                args=api_args,
+            )
+
             if script.infotext_fields is not None:
                 self.infotext_fields += script.infotext_fields
 
@@ -491,7 +519,7 @@ class ScriptRunner:
                 module = script_loading.load_module(script.filename)
                 cache[filename] = module
 
-            for key, script_class in module.__dict__.items():
+            for script_class in module.__dict__.values():
                 if type(script_class) == type and issubclass(script_class, Script):
                     self.scripts[si] = script_class()
                     self.scripts[si].filename = filename
@@ -526,7 +554,7 @@ def add_classes_to_gradio_component(comp):
     this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
     """
 
-    comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
+    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
 
     if getattr(comp, 'multiselect', False):
         comp.elem_classes.append('multiselect')

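The `setup_ui` additions above mirror each script's Gradio controls into `modules.api.models.ScriptInfo` / `ScriptArg` objects so the API can describe available scripts and their arguments. A rough, self-contained sketch of what that probing loop produces, using dataclass stand-ins instead of the real pydantic models and a fake slider control (only the field names `label`, `value`, `minimum`, `maximum`, `step`, `choices` are taken from the diff):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ScriptArg:                       # stand-in for modules.api.models.ScriptArg
    label: str = ""
    value: Any = None
    minimum: Optional[float] = None
    maximum: Optional[float] = None
    step: Optional[float] = None
    choices: Optional[list] = None


@dataclass
class FakeSlider:                      # stands in for a gr.Slider returned by a script's ui()
    label: str = "Denoising strength"
    value: float = 0.4
    minimum: float = 0.0
    maximum: float = 1.0
    step: float = 0.01


control = FakeSlider()
arg_info = ScriptArg(label=control.label or "")
for field_name in ("value", "minimum", "maximum", "step", "choices"):
    v = getattr(control, field_name, None)   # same probing loop as in setup_ui above
    if v is not None:
        setattr(arg_info, field_name, v)

print(arg_info)
# ScriptArg(label='Denoising strength', value=0.4, minimum=0.0, maximum=1.0, step=0.01, choices=None)
```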
+ 1 - 1
modules/scripts_auto_postprocessing.py

@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
         return self.postprocessing_controls.values()
 
     def postprocess_image(self, p, script_pp, *args):
-        args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+        args_dict = dict(zip(self.postprocessing_controls, args))
 
         pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
         pp.info = {}

+ 4 - 4
modules/scripts_postprocessing.py

@@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
     def initialize_scripts(self, scripts_data):
         self.scripts = []
 
-        for script_class, path, basedir, script_module in scripts_data:
-            script: ScriptPostprocessing = script_class()
-            script.filename = path
+        for script_data in scripts_data:
+            script: ScriptPostprocessing = script_data.script_class()
+            script.filename = script_data.path
 
             if script.name == "Simple Upscale":
                 continue
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
             script_args = args[script.args_from:script.args_to]
 
             process_args = {}
-            for (name, component), value in zip(script.controls.items(), script_args):
+            for (name, _component), value in zip(script.controls.items(), script_args):
                 process_args[name] = value
 
             script.process(pp, **process_args)

+ 1 - 1
modules/sd_disable_initialization.py

@@ -61,7 +61,7 @@ class DisableInitialization:
                 if res is None:
                     res = original(url, *args, local_files_only=False, **kwargs)
                 return res
-            except Exception as e:
+            except Exception:
                 return original(url, *args, local_files_only=False, **kwargs)
 
         def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):

+ 13 - 10
modules/sd_hijack.py

@@ -3,7 +3,7 @@ from torch.nn.functional import silu
 from types import MethodType
 
 import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -34,10 +34,10 @@ def apply_optimizations():
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-    
+
     optimization_method = None
 
-    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
 
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
@@ -92,12 +92,12 @@ def fix_checkpoint():
 def weighted_loss(sd_model, pred, target, mean=True):
     #Calculate the weight normally, but ignore the mean
     loss = sd_model._old_get_loss(pred, target, mean=False)
-    
+
     #Check if we have weights available
     weight = getattr(sd_model, '_custom_loss_weight', None)
     if weight is not None:
         loss *= weight
-    
+
     #Return the loss, as mean if specified
     return loss.mean() if mean else loss
 
@@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
     try:
         #Temporarily append weights to a place accessible during loss calc
         sd_model._custom_loss_weight = w
-        
+
         #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
         #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
         if not hasattr(sd_model, '_old_get_loss'):
@@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
         try:
             #Delete temporary weights if appended
             del sd_model._custom_loss_weight
-        except AttributeError as e:
+        except AttributeError:
             pass
-            
+
         #If we have an old loss function, reset the loss function to the original one
         if hasattr(sd_model, '_old_get_loss'):
             sd_model.get_loss = sd_model._old_get_loss
@@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
 def undo_weighted_forward(sd_model):
     try:
         del sd_model.weighted_forward
-    except AttributeError as e:
+    except AttributeError:
         pass
 
 
@@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
 
     def undo_hijack(self, m):
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
-            m.cond_stage_model = m.cond_stage_model.wrapped 
+            m.cond_stage_model = m.cond_stage_model.wrapped
 
         elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
@@ -216,6 +216,9 @@ class StableDiffusionModelHijack:
         self.comments = []
 
     def get_prompt_lengths(self, text):
+        if self.clip is None:
+            return "-", "-"
+
         _, token_count = self.clip.process_texts([text])
 
         return token_count, self.clip.get_target_prompt_token_count(token_count)

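The `weighted_loss` / `weighted_forward` helpers above work by stashing per-sample weights on the model, temporarily swapping `get_loss` for a weight-aware version, and restoring everything afterwards. A stripped-down, model-free sketch of that swap pattern, assuming only the `MethodType` import visible at the top of the hunk; the toy `Model` class and tensors are illustrative, not the real LatentDiffusion API:

```python
from types import MethodType

import torch


class Model:
    def get_loss(self, pred, target, mean=True):
        loss = (pred - target) ** 2
        return loss.mean() if mean else loss


def weighted_loss(self, pred, target, mean=True):
    # Same idea as above: compute the unreduced loss, scale by the stashed weight, reduce last.
    loss = self._old_get_loss(pred, target, mean=False)
    weight = getattr(self, "_custom_loss_weight", None)
    if weight is not None:
        loss = loss * weight
    return loss.mean() if mean else loss


m = Model()
m._custom_loss_weight = torch.tensor([1.0, 0.0])   # illustrative per-sample weights
if not hasattr(m, "_old_get_loss"):                # keep the original exactly once
    m._old_get_loss = m.get_loss
m.get_loss = MethodType(weighted_loss, m)          # bind the replacement to this instance

print(m.get_loss(torch.tensor([1.0, 1.0]), torch.tensor([0.0, 0.0])))  # tensor(0.5000)
```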
+ 1 - 1
modules/sd_hijack_clip.py

@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             self.hijack.fixes = [x.fixes for x in batch_chunk]
 
             for fixes in self.hijack.fixes:
-                for position, embedding in fixes:
+                for _position, embedding in fixes:
                     used_embeddings[embedding.name] = embedding
 
             z = self.process_tokens(tokens, multipliers)

+ 2 - 1
modules/sd_hijack_clip_old.py

@@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
     self.hijack.comments += hijack_comments
 
     if len(used_custom_terms) > 0:
-        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+        embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
+        self.hijack.comments.append(f"Used embeddings: {embedding_names}")
 
     self.hijack.fixes = hijack_fixes
     return self.process_tokens(remade_batch_tokens, batch_multipliers)

+ 2 - 8
modules/sd_hijack_inpainting.py

@@ -1,16 +1,10 @@
-import os
 import torch
 
-from einops import repeat
-from omegaconf import ListConfig
-
 import ldm.models.diffusion.ddpm
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
 
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.ddim import noise_like
 from ldm.models.diffusion.sampling_util import norm_thresholding
 
 
@@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
 
             if isinstance(c, dict):
                 assert isinstance(unconditional_conditioning, dict)
-                c_in = dict()
+                c_in = {}
                 for k in c:
                     if isinstance(c[k], list):
                         c_in[k] = [

+ 2 - 5
modules/sd_hijack_ip2p.py

@@ -1,8 +1,5 @@
-import collections
 import os.path
-import sys
-import gc
-import time
+
 
 def should_hijack_ip2p(checkpoint_info):
     from modules import sd_models_config
@@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
     ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
     cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
 
-    return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
+    return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename

+ 26 - 24
modules/sd_hijack_optimizations.py

@@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     v_in = self.to_v(context_v)
     del context, context_k, context_v, x
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
     dtype = q.dtype
@@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
             end = i + 2
             s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
             s1 *= self.scale
-    
+
             s2 = s1.softmax(dim=-1)
             del s1
-    
+
             r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
             del s2
         del q, k, v
@@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     with devices.without_autocast(disable=not shared.opts.upcast_attn):
         k_in = k_in * self.scale
-    
+
         del context, x
-    
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+
+        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
         del q_in, k_in, v_in
-    
+
         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-    
+
         mem_free_total = get_available_vram()
-    
+
         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
         modifier = 3 if q.element_size() == 2 else 2.5
         mem_required = tensor_size * modifier
         steps = 1
-    
+
         if mem_required > mem_free_total:
             steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
             # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
             #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-    
+
         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                                f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-    
+
         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-    
+
             s2 = s1.softmax(dim=-1, dtype=q.dtype)
             del s1
-    
+
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
             del s2
-    
+
         del q, k, v
 
     r1 = r1.to(dtype)
@@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
 
     with devices.without_autocast(disable=not shared.opts.upcast_attn):
         k = k * self.scale
-    
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
         r = einsum_op(q, k, v)
     r = r.to(dtype)
     return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
     v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
 
+    if q.device.type == 'mps':
+        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+
     dtype = q.dtype
     if shared.opts.upcast_attn:
         q, k = q.float(), k.float()
@@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
         # the big matmul fits into our memory limit; do everything in 1 chunk,
         # i.e. send it down the unchunked fast-path
-        query_chunk_size = q_tokens
         kv_chunk_size = k_tokens
 
     with devices.without_autocast(disable=q.dtype == v.dtype):
@@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
     dtype = q.dtype
@@ -367,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
     q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
     k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
     v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
-    
+
     del q_in, k_in, v_in
 
     dtype = q.dtype
@@ -449,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
         h3 += x
 
         return h3
-    
+
 def xformers_attnblock_forward(self, x):
     try:
         h_ = x
@@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x):
         k = self.k(h_)
         v = self.v(h_)
         b, c, h, w = q.shape
-        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
         dtype = q.dtype
         if shared.opts.upcast_attn:
             q, k = q.float(), k.float()
@@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x):
     k = self.k(h_)
     v = self.v(h_)
     b, c, h, w = q.shape
-    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
     dtype = q.dtype
     if shared.opts.upcast_attn:
         q, k = q.float(), k.float()
@@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
     k = self.k(h_)
     v = self.v(h_)
     b, c, h, w = q.shape
-    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
     q = q.contiguous()
     k = k.contiguous()
     v = v.contiguous()

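Most of the non-whitespace changes in this file replace `map(lambda t: rearrange(...), (...))` with equivalent generator expressions; the per-head reshaping they perform is unchanged. A small self-contained check of that equivalence (head count and tensor sizes are arbitrary example values):

```python
import torch
from einops import rearrange

h = 8                                                  # attention heads (example value)
q_in, k_in, v_in = (torch.randn(2, 77, h * 64) for _ in range(3))

# old style
q1, k1, v1 = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
# new style, as in the diff above
q2, k2, v2 = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))

assert q1.shape == q2.shape == (2 * h, 77, 64)
assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)
```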
+ 1 - 1
modules/sd_hijack_unet.py

@@ -18,7 +18,7 @@ class TorchHijackForUnet:
         if hasattr(torch, item):
             return getattr(torch, item)
 
-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
 
     def cat(self, tensors, *args, **kwargs):
         if len(tensors) == 2:

+ 0 - 2
modules/sd_hijack_xlmr.py

@@ -1,8 +1,6 @@
-import open_clip.tokenizer
 import torch
 
 from modules import sd_hijack_clip, devices
-from modules.shared import opts
 
 
 class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):

+ 89 - 27
modules/sd_models.py

@@ -2,6 +2,8 @@ import collections
 import os.path
 import sys
 import gc
+import threading
+
 import torch
 import re
 import safetensors.torch
@@ -13,9 +15,9 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
+import tomesd
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -45,20 +47,29 @@ class CheckpointInfo:
         self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
         self.hash = model_hash(filename)
 
-        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
+        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
         self.shorthash = self.sha256[0:10] if self.sha256 else None
 
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
 
         self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
+        self.metadata = {}
+
+        _, ext = os.path.splitext(self.filename)
+        if ext.lower() == ".safetensors":
+            try:
+                self.metadata = read_metadata_from_safetensors(filename)
+            except Exception as e:
+                errors.display(e, f"reading checkpoint metadata: {filename}")
+
     def register(self):
         checkpoints_list[self.title] = self
         for id in self.ids:
             checkpoint_alisases[id] = self
 
     def calculate_shorthash(self):
-        self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
         if self.sha256 is None:
             return
 
@@ -76,8 +87,7 @@ class CheckpointInfo:
 
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
-    from transformers import logging, CLIPModel
+    from transformers import logging, CLIPModel  # noqa: F401
 
     logging.set_verbosity_error()
 except Exception:
@@ -156,7 +166,7 @@ def model_hash(filename):
 
 def select_checkpoint():
     model_checkpoint = shared.opts.sd_model_checkpoint
-        
+
     checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
     if checkpoint_info is not None:
         return checkpoint_info
@@ -228,7 +238,7 @@ def read_metadata_from_safetensors(filename):
             if isinstance(v, str) and v[0:1] == '{':
                 try:
                     res[k] = json.loads(v)
-                except Exception as e:
+                except Exception:
                     pass
 
         return res
@@ -363,7 +373,7 @@ def enable_midas_autodownload():
         if not os.path.exists(path):
             if not os.path.exists(midas_path):
                 mkdir(midas_path)
-    
+
             print(f"Downloading midas model weights for {model_type} to {path}")
             request.urlretrieve(midas_urls[model_type], path)
             print(f"{model_type} downloaded")
@@ -395,13 +405,42 @@ def repair_config(sd_config):
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 
-def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
+
+class SdModelData:
+    def __init__(self):
+        self.sd_model = None
+        self.lock = threading.Lock()
+
+    def get_sd_model(self):
+        if self.sd_model is None:
+            with self.lock:
+                if self.sd_model is not None:
+                    return self.sd_model
+
+                try:
+                    load_model()
+                except Exception as e:
+                    errors.display(e, "loading stable diffusion model")
+                    print("", file=sys.stderr)
+                    print("Stable diffusion model failed to load", file=sys.stderr)
+                    self.sd_model = None
+
+        return self.sd_model
+
+    def set_sd_model(self, v):
+        self.sd_model = v
+
+
+model_data = SdModelData()
+
+
+def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
-    if shared.sd_model:
-        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
-        shared.sd_model = None
+    if model_data.sd_model:
+        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        model_data.sd_model = None
         gc.collect()
         devices.torch_gc()
 
@@ -430,7 +469,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
             sd_model = instantiate_from_config(sd_config.model)
-    except Exception as e:
+    except Exception:
         pass
 
     if sd_model is None:
@@ -455,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
     timer.record("hijack")
 
     sd_model.eval()
-    shared.sd_model = sd_model
+    model_data.sd_model = sd_model
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
 
@@ -475,7 +514,7 @@ def reload_model_weights(sd_model=None, info=None):
     checkpoint_info = info or select_checkpoint()
 
     if not sd_model:
-        sd_model = shared.sd_model
+        sd_model = model_data.sd_model
 
     if sd_model is None:  # previous model load failed
         current_checkpoint_info = None
@@ -501,13 +540,12 @@ def reload_model_weights(sd_model=None, info=None):
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
         del sd_model
-        checkpoints_loaded.clear()
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
-        return shared.sd_model
+        return model_data.sd_model
 
     try:
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
-    except Exception as e:
+    except Exception:
         print("Failed to load checkpoint, restoring previous")
         load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
@@ -526,17 +564,15 @@ def reload_model_weights(sd_model=None, info=None):
 
     return sd_model
 
+
 def unload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
     timer = Timer()
 
-    if shared.sd_model:
-
-        # shared.sd_model.cond_stage_model.to(devices.cpu)
-        # shared.sd_model.first_stage_model.to(devices.cpu)
-        shared.sd_model.to(devices.cpu)
-        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
-        shared.sd_model = None
+    if model_data.sd_model:
+        model_data.sd_model.to(devices.cpu)
+        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        model_data.sd_model = None
         sd_model = None
         gc.collect()
         devices.torch_gc()
@@ -544,4 +580,30 @@ def unload_model_weights(sd_model=None, info=None):
 
     print(f"Unloaded weights {timer.summary()}.")
 
-    return sd_model
+    return sd_model
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+    """
+    Applies speed and memory optimizations from tomesd.
+    """
+
+    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+    if current_token_merging_ratio == token_merging_ratio:
+        return
+
+    if current_token_merging_ratio > 0:
+        tomesd.remove_patch(sd_model)
+
+    if token_merging_ratio > 0:
+        tomesd.apply_patch(
+            sd_model,
+            ratio=token_merging_ratio,
+            use_rand=False,  # can cause issues with some samplers
+            merge_attn=True,
+            merge_crossattn=False,
+            merge_mlp=False
+        )
+
+    sd_model.applied_token_merged_ratio = token_merging_ratio

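`SdModelData` above defers checkpoint loading until something actually asks for the model, using a lock plus a re-check so that concurrent callers (for example, an API request arriving during startup) trigger only one load. A stripped-down sketch of that double-checked pattern with a stand-in loader; everything except the locking structure is illustrative:

```python
import threading
import time


class LazyModel:
    def __init__(self, loader):
        self.value = None
        self.loader = loader
        self.lock = threading.Lock()

    def get(self):
        if self.value is None:                 # fast path: already loaded
            with self.lock:
                if self.value is not None:     # another thread finished the load while we waited
                    return self.value
                self.value = self.loader()     # only one thread runs the expensive load
        return self.value


def slow_load():
    time.sleep(0.1)                            # stand-in for reading a multi-GB checkpoint
    return "sd_model"


model_data = LazyModel(slow_load)
threads = [threading.Thread(target=model_data.get) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(model_data.get())                        # "sd_model", loaded exactly once
```

The new `apply_token_merging` helper uses a similar guard: it records the applied ratio on the model and only removes or re-applies the tomesd patch when the requested ratio actually changes.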
+ 1 - 2
modules/sd_models_config.py

@@ -1,4 +1,3 @@
-import re
 import os
 
 import torch
@@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info):
     if info is None:
         return None
 
-    config = os.path.splitext(info.filename)[0] + ".yaml"
+    config = f"{os.path.splitext(info.filename)[0]}.yaml"
     if os.path.exists(config):
         return config
 

+ 8 - 2
modules/sd_samplers.py

@@ -1,7 +1,7 @@
 from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
 
 # imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401
 
 all_samplers = [
     *sd_samplers_kdiffusion.samplers_data_k_diffusion,
@@ -14,12 +14,18 @@ samplers_for_img2img = []
 samplers_map = {}
 
 
-def create_sampler(name, model):
+def find_sampler_config(name):
     if name is not None:
         config = all_samplers_map.get(name, None)
     else:
         config = all_samplers[0]
 
+    return config
+
+
+def create_sampler(name, model):
+    config = find_sampler_config(name)
+
     assert config is not None, f'bad sampler name: {name}'
 
     sampler = config.constructor(model)

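Factoring the lookup out of `create_sampler` lets other code resolve a sampler name to its configuration without needing a model. A short usage sketch, assuming it runs inside the webui environment; the commented `create_sampler` call and the `shared.sd_model` reference are the usual follow-up, not part of this hunk:

```python
from modules import sd_samplers

requested = "Euler a"                                   # e.g. a sampler name arriving via the API
config = sd_samplers.find_sampler_config(requested)

if config is None:
    raise ValueError(f"bad sampler name: {requested}")  # same condition create_sampler asserts on

# Only now, with a validated config and an actual model available, build the sampler:
# sampler = sd_samplers.create_sampler(requested, shared.sd_model)
```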
Some files were not shown because too many files changed in this diff