Bläddra i källkod

add missing import, simplify code, use patches module for #13276

AUTOMATIC1111 1 år sedan
förälder
incheckning
87b50397a6
1 ändrade filer med 12 tillägg och 7 borttagningar
  1. 12 7
      modules/sd_models.py

+ 12 - 7
modules/sd_models.py

@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 import tomesd
 import numpy as np
@@ -130,6 +130,8 @@ except Exception:
 
 
 def setup_model():
+    """called once at startup to do various one-time tasks related to SD models"""
+
     os.makedirs(model_path, exist_ok=True)
 
     enable_midas_autodownload()
@@ -458,14 +460,17 @@ def enable_midas_autodownload():
 
 
 def patch_given_betas():
-    original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
+    import ldm.models.diffusion.ddpm
+
     def patched_register_schedule(*args, **kwargs):
-        if args[1] is not None and isinstance(args[1], ListConfig):
-            modified_args = list(args)  # Convert args tuple to a list
-            modified_args[1] = np.array(args[1])  # Modify the desired element
-            args = tuple(modified_args)  # Convert the list back to a tuple
+        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+        if isinstance(args[1], ListConfig):
+            args = (args[0], np.array(args[1]), *args[2:])
+
         original_register_schedule(*args, **kwargs)
-    ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
+
+    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
 
 
 def repair_config(sd_config):