浏览代码

make existing script loading and new preload code use same code for loading modules
limit extension preload scripts to just one file named preload.py

AUTOMATIC 2 年之前
父节点
当前提交
a1a376331c
共有 4 个文件被更改,包括 53 次插入53 次删除
  1. 0 21
      modules/extensions.py
  2. 34 0
      modules/script_loading.py
  3. 17 29
      modules/scripts.py
  4. 2 3
      modules/shared.py

+ 0 - 21
modules/extensions.py

@@ -1,7 +1,6 @@
 import os
 import sys
 import traceback
-from importlib.machinery import SourceFileLoader
 
 import git
 
@@ -85,23 +84,3 @@ def list_extensions():
         extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
         extensions.append(extension)
 
-
-def preload_extensions(parser):
-    if not os.path.isdir(extensions_dir):
-        return
-
-    for dirname in sorted(os.listdir(extensions_dir)):
-        path = os.path.join(extensions_dir, dirname)
-        if not os.path.isdir(path):
-            continue
-        for file in os.listdir(path):
-            if "preload.py" in file:
-                full_file = os.path.join(path, file)
-                print(f"Got preload file: {full_file}")
-
-                try:
-                    ext = SourceFileLoader("preload", full_file).load_module()
-                    parser = ext.preload(parser)
-                except Exception as e:
-                    print(f"Exception preloading script: {e}")
-    return parser

+ 34 - 0
modules/script_loading.py

@@ -0,0 +1,34 @@
+import os
+import sys
+import traceback
+from types import ModuleType
+
+
+def load_module(path):
+    with open(path, "r", encoding="utf8") as file:
+        text = file.read()
+
+    compiled = compile(text, path, 'exec')
+    module = ModuleType(os.path.basename(path))
+    exec(compiled, module.__dict__)
+
+    return module
+
+
+def preload_extensions(extensions_dir, parser):
+    if not os.path.isdir(extensions_dir):
+        return
+
+    for dirname in sorted(os.listdir(extensions_dir)):
+        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
+        if not os.path.isfile(preload_script):
+            continue
+
+        try:
+            module = load_module(preload_script)
+            if hasattr(module, 'preload'):
+                module.preload(parser)
+
+        except Exception:
+            print(f"Error running preload() for {preload_script}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)

+ 17 - 29
modules/scripts.py

@@ -6,7 +6,7 @@ from collections import namedtuple
 import gradio as gr
 
 from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions
+from modules import shared, paths, script_callbacks, extensions, script_loading
 
 AlwaysVisible = object()
 
@@ -161,13 +161,7 @@ def load_scripts():
                 sys.path = [scriptfile.basedir] + sys.path
             current_basedir = scriptfile.basedir
 
-            with open(scriptfile.path, "r", encoding="utf8") as file:
-                text = file.read()
-
-            from types import ModuleType
-            compiled = compile(text, scriptfile.path, 'exec')
-            module = ModuleType(scriptfile.filename)
-            exec(compiled, module.__dict__)
+            module = script_loading.load_module(scriptfile.path)
 
             for key, script_class in module.__dict__.items():
                 if type(script_class) == type and issubclass(script_class, Script):
@@ -328,27 +322,21 @@ class ScriptRunner:
 
     def reload_sources(self, cache):
         for si, script in list(enumerate(self.scripts)):
-            with open(script.filename, "r", encoding="utf8") as file:
-                args_from = script.args_from
-                args_to = script.args_to
-                filename = script.filename
-                text = file.read()
-
-                from types import ModuleType
-
-                module = cache.get(filename, None)
-                if module is None:
-                    compiled = compile(text, filename, 'exec')
-                    module = ModuleType(script.filename)
-                    exec(compiled, module.__dict__)
-                    cache[filename] = module
-
-                for key, script_class in module.__dict__.items():
-                    if type(script_class) == type and issubclass(script_class, Script):
-                        self.scripts[si] = script_class()
-                        self.scripts[si].filename = filename
-                        self.scripts[si].args_from = args_from
-                        self.scripts[si].args_to = args_to
+            args_from = script.args_from
+            args_to = script.args_to
+            filename = script.filename
+
+            module = cache.get(filename, None)
+            if module is None:
+                module = script_loading.load_module(script.filename)
+                cache[filename] = module
+
+            for key, script_class in module.__dict__.items():
+                if type(script_class) == type and issubclass(script_class, Script):
+                    self.scripts[si] = script_class()
+                    self.scripts[si].filename = filename
+                    self.scripts[si].args_from = args_from
+                    self.scripts[si].args_to = args_to
 
 
 scripts_txt2img = ScriptRunner()

+ 2 - 3
modules/shared.py

@@ -3,7 +3,6 @@ import datetime
 import json
 import os
 import sys
-from collections import OrderedDict
 import time
 
 import gradio as gr
@@ -15,7 +14,7 @@ import modules.memmon
 import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, sd_models, localization, sd_vae, extensions
+from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
 from modules.hypernetworks import hypernetwork
 from modules.paths import models_path, script_path, sd_path
 
@@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
 parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 
-extensions.preload_extensions(parser)
+script_loading.preload_extensions(extensions.extensions_dir, parser)
 
 cmd_opts = parser.parse_args()