Browse Source

add built-in extension system
add support for adding upscalers in extensions
move LDSR, ScuNET and SwinIR to built-in extensions

AUTOMATIC 2 years ago
parent
commit
b6e5edd746

+ 0 - 0
modules/ldsr_model_arch.py → extensions-builtin/LDSR/ldsr_model_arch.py


+ 6 - 0
extensions-builtin/LDSR/preload.py

@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))

+ 11 - 2
modules/ldsr_model.py → extensions-builtin/LDSR/scripts/ldsr_model.py

@@ -5,8 +5,8 @@ import traceback
 from basicsr.utils.download_util import load_file_from_url
 
 from modules.upscaler import Upscaler, UpscalerData
-from modules.ldsr_model_arch import LDSR
-from modules import shared
+from ldsr_model_arch import LDSR
+from modules import shared, script_callbacks
 
 
 class UpscalerLDSR(Upscaler):
@@ -52,3 +52,12 @@ class UpscalerLDSR(Upscaler):
             return img
         ddim_steps = shared.opts.ldsr_steps
         return ldsr.super_resolution(img, ddim_steps, self.scale)
+
+
+def on_ui_settings():
+    import gradio as gr
+
+    shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)

+ 6 - 0
extensions-builtin/ScuNET/preload.py

@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))

+ 3 - 3
modules/scunet_model.py → extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
 
 import modules.upscaler
 from modules import devices, modelloader
-from modules.scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet as net
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -49,7 +49,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         if model is None:
             return img
 
-        device = devices.device_scunet
+        device = devices.get_device_for('scunet')
         img = np.array(img)
         img = img[:, :, ::-1]
         img = np.moveaxis(img, 2, 0) / 255
@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         return PIL.Image.fromarray(output, 'RGB')
 
     def load_model(self, path: str):
-        device = devices.device_scunet
+        device = devices.get_device_for('scunet')
         if "http" in path:
             filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
                                           progress=True)

+ 0 - 0
modules/scunet_model_arch.py → extensions-builtin/ScuNET/scunet_model_arch.py


+ 6 - 0
extensions-builtin/SwinIR/preload.py

@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))

+ 20 - 9
modules/swinir_model.py → extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -7,13 +7,16 @@ from PIL import Image
 from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
-from modules import modelloader, devices
+from modules import modelloader, devices, script_callbacks, shared
 from modules.shared import cmd_opts, opts
-from modules.swinir_model_arch import SwinIR as net
-from modules.swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR as net
+from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
 
 
+device_swinir = devices.get_device_for('swinir')
+
+
 class UpscalerSwinIR(Upscaler):
     def __init__(self, dirname):
         self.name = "SwinIR"
@@ -38,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
         model = self.load_model(model_file)
         if model is None:
             return img
-        model = model.to(devices.device_swinir)
+        model = model.to(device_swinir, dtype=devices.dtype)
         img = upscale(img, model)
         try:
             torch.cuda.empty_cache()
@@ -90,8 +93,6 @@ class UpscalerSwinIR(Upscaler):
             model.load_state_dict(pretrained_model[params], strict=True)
         else:
             model.load_state_dict(pretrained_model, strict=True)
-        if not cmd_opts.no_half:
-            model = model.half()
         return model
 
 
@@ -107,7 +108,7 @@ def upscale(
     img = img[:, :, ::-1]
     img = np.moveaxis(img, 2, 0) / 255
     img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(devices.device_swinir)
+    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
     with torch.no_grad(), devices.autocast():
         _, _, h_old, w_old = img.size()
         h_pad = (h_old // window_size + 1) * window_size - h_old
@@ -135,8 +136,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     stride = tile - tile_overlap
     h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
     w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
+    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
+    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
 
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
@@ -155,3 +156,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     output = E.div_(W)
 
     return output
+
+
+def on_ui_settings():
+    import gradio as gr
+
+    shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)

+ 0 - 0
modules/swinir_model_arch.py → extensions-builtin/SwinIR/swinir_model_arch.py


+ 0 - 0
modules/swinir_model_arch_v2.py → extensions-builtin/SwinIR/swinir_model_arch_v2.py


+ 10 - 1
modules/devices.py

@@ -44,6 +44,15 @@ def get_optimal_device():
     return cpu
 
 
+def get_device_for(task):
+    from modules import shared
+
+    if task in shared.cmd_opts.use_cpu:
+        return cpu
+
+    return get_optimal_device()
+
+
 def torch_gc():
     if torch.cuda.is_available():
         with torch.cuda.device(get_cuda_device_string()):
@@ -67,7 +76,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
 

+ 16 - 6
modules/extensions.py

@@ -8,6 +8,7 @@ from modules import paths, shared
 
 extensions = []
 extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
 
 
 def active():
@@ -15,12 +16,13 @@ def active():
 
 
 class Extension:
-    def __init__(self, name, path, enabled=True):
+    def __init__(self, name, path, enabled=True, is_builtin=False):
         self.name = name
         self.path = path
         self.enabled = enabled
         self.status = ''
         self.can_update = False
+        self.is_builtin = is_builtin
 
         repo = None
         try:
@@ -79,11 +81,19 @@ def list_extensions():
     if not os.path.isdir(extensions_dir):
         return
 
-    for dirname in sorted(os.listdir(extensions_dir)):
-        path = os.path.join(extensions_dir, dirname)
-        if not os.path.isdir(path):
-            continue
+    paths = []
+    for dirname in [extensions_dir, extensions_builtin_dir]:
+        if not os.path.isdir(dirname):
+            return
 
-        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
+        for extension_dirname in sorted(os.listdir(dirname)):
+            path = os.path.join(dirname, extension_dirname)
+            if not os.path.isdir(path):
+                continue
+
+            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+
+    for dirname, path, is_builtin in paths:
+        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
         extensions.append(extension)
 

+ 5 - 15
modules/modelloader.py

@@ -124,10 +124,9 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
 
 
 def load_upscalers():
-    sd = shared.script_path
     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
     # so we'll try to import any _model.py files before looking in __subclasses__
-    modules_dir = os.path.join(sd, "modules")
+    modules_dir = os.path.join(shared.script_path, "modules")
     for file in os.listdir(modules_dir):
         if "_model.py" in file:
             model_name = file.replace("_model.py", "")
@@ -136,22 +135,13 @@ def load_upscalers():
                 importlib.import_module(full_model)
             except:
                 pass
+
     datas = []
-    c_o = vars(shared.cmd_opts)
+    commandline_options = vars(shared.cmd_opts)
     for cls in Upscaler.__subclasses__():
         name = cls.__name__
-        module_name = cls.__module__
-        module = importlib.import_module(module_name)
-        class_ = getattr(module, name)
         cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
-        opt_string = None
-        try:
-            if cmd_name in c_o:
-                opt_string = c_o[cmd_name]
-        except:
-            pass
-        scaler = class_(opt_string)
-        for child in scaler.scalers:
-            datas.append(child)
+        scaler = cls(commandline_options.get(cmd_name, None))
+        datas += scaler.scalers
 
     shared.sd_upscalers = datas

+ 4 - 9
modules/shared.py

@@ -50,9 +50,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
 parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
 parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
 parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
-parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
-parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
-parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
 parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
 parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
 parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
@@ -61,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
 parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -95,6 +92,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 
 script_loading.preload_extensions(extensions.extensions_dir, parser)
+script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
 
 cmd_opts = parser.parse_args()
 
@@ -112,8 +110,8 @@ restricted_opts = {
 
 cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
 
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
+    (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
 
 device = devices.device
 weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -326,9 +324,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
     "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
     "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
     "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
-    "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
-    "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
-    "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
     "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
 }))

+ 0 - 1
modules/ui.py

@@ -28,7 +28,6 @@ import modules.codeformer_model
 import modules.generation_parameters_copypaste as parameters_copypaste
 import modules.gfpgan_model
 import modules.hypernetworks.ui
-import modules.ldsr_model
 import modules.scripts
 import modules.shared as shared
 import modules.styles

+ 7 - 1
modules/ui_extensions.py

@@ -78,6 +78,12 @@ def extension_table():
     """
 
     for ext in extensions.extensions:
+        remote = ""
+        if ext.is_builtin:
+            remote = "built-in"
+        elif ext.remote:
+            remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
+
         if ext.can_update:
             ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
         else:
@@ -86,7 +92,7 @@ def extension_table():
         code += f"""
             <tr>
                 <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
-                <td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
+                <td>{remote}</td>
                 <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
             </tr>
     """

+ 4 - 1
webui.py

@@ -53,10 +53,11 @@ def initialize():
     codeformer.setup_model(cmd_opts.codeformer_models_path)
     gfpgan.setup_model(cmd_opts.gfpgan_models_path)
     shared.face_restorers.append(modules.face_restoration.FaceRestoration())
-    modelloader.load_upscalers()
 
     modules.scripts.load_scripts()
 
+    modelloader.load_upscalers()
+
     modules.sd_vae.refresh_vae_list()
     modules.sd_models.load_model()
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
@@ -177,6 +178,8 @@ def webui():
 
         print('Reloading custom scripts')
         modules.scripts.reload_scripts()
+        modelloader.load_upscalers()
+
         print('Reloading modules: modules.ui')
         importlib.reload(modules.ui)
         print('Refreshing Model List')