Browse Source

Re-implement universal model loading

d8ahazard 2 years ago
parent
commit
740070ea9c

+ 25 - 10
modules/codeformer_model.py

@@ -5,22 +5,28 @@ import traceback
 import cv2
 import cv2
 import torch
 import torch
 
 
-from modules import shared, devices
-from modules.paths import script_path
+from modules import shared, devices, modelloader
+from modules.paths import script_path, models_path
 import modules.shared
 import modules.shared
 import modules.face_restoration
 import modules.face_restoration
 from importlib import reload
 from importlib import reload
 
 
-# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
-# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
+# codeformer people made a choice to include modified basicsr library to their project, which makes
+# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
 # I am making a choice to include some files from codeformer to work around this issue.
 # I am making a choice to include some files from codeformer to work around this issue.
-
-pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 
 
 have_codeformer = False
 have_codeformer = False
 codeformer = None
 codeformer = None
 
 
-def setup_codeformer():
+
+def setup_model(dirname):
+    global model_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+
     path = modules.paths.paths.get("CodeFormer", None)
     path = modules.paths.paths.get("CodeFormer", None)
     if path is None:
     if path is None:
         return
         return
@@ -44,16 +50,22 @@ def setup_codeformer():
             def name(self):
             def name(self):
                 return "CodeFormer"
                 return "CodeFormer"
 
 
-            def __init__(self):
+            def __init__(self, dirname):
                 self.net = None
                 self.net = None
                 self.face_helper = None
                 self.face_helper = None
+                self.cmd_dir = dirname
 
 
             def create_models(self):
             def create_models(self):
 
 
                 if self.net is not None and self.face_helper is not None:
                 if self.net is not None and self.face_helper is not None:
                     self.net.to(devices.device_codeformer)
                     self.net.to(devices.device_codeformer)
                     return self.net, self.face_helper
                     return self.net, self.face_helper
-
+                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
+                if len(model_paths) != 0:
+                    ckpt_path = model_paths[0]
+                else:
+                    print("Unable to load codeformer model.")
+                    return None, None
                 net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                 net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                 ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
                 ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
                 checkpoint = torch.load(ckpt_path)['params_ema']
                 checkpoint = torch.load(ckpt_path)['params_ema']
@@ -74,6 +86,9 @@ def setup_codeformer():
                 original_resolution = np_image.shape[0:2]
                 original_resolution = np_image.shape[0:2]
 
 
                 self.create_models()
                 self.create_models()
+                if self.net is None or self.face_helper is None:
+                    return np_image
+
                 self.face_helper.clean_all()
                 self.face_helper.clean_all()
                 self.face_helper.read_image(np_image)
                 self.face_helper.read_image(np_image)
                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -114,7 +129,7 @@ def setup_codeformer():
         have_codeformer = True
         have_codeformer = True
 
 
         global codeformer
         global codeformer
-        codeformer = FaceRestorerCodeFormer()
+        codeformer = FaceRestorerCodeFormer(dirname)
         shared.face_restorers.append(codeformer)
         shared.face_restorers.append(codeformer)
 
 
     except Exception:
     except Exception:

+ 41 - 15
modules/esrgan_model.py

@@ -5,15 +5,35 @@ import traceback
 import numpy as np
 import numpy as np
 import torch
 import torch
 from PIL import Image
 from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
 
 
 import modules.esrgam_model_arch as arch
 import modules.esrgam_model_arch as arch
+import modules.images
 from modules import shared
 from modules import shared
-from modules.shared import opts
+from modules import shared, modelloader
 from modules.devices import has_mps
 from modules.devices import has_mps
-import modules.images
-
+from modules.paths import models_path
+from modules.shared import opts
 
 
-def load_model(filename):
+model_dir = "ESRGAN"
+model_path = os.path.join(models_path, model_dir)
+model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+model_name = "ESRGAN_x4.pth"
+
+
+def load_model(path: str, name: str):
+    global model_path
+    global model_url
+    global model_dir
+    global model_name
+    if "http" in path:
+        filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
+    else:
+        filename = path
+    if not os.path.exists(filename) or filename is None:
+        print("Unable to load %s from %s" % (model_dir, filename))
+        return None
+    print("Loading %s from %s" % (model_dir, filename))
     # this code is adapted from https://github.com/xinntao/ESRGAN
     # this code is adapted from https://github.com/xinntao/ESRGAN
     pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
     pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
     crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
     crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
@@ -118,24 +138,30 @@ def esrgan_upscale(model, img):
 class UpscalerESRGAN(modules.images.Upscaler):
 class UpscalerESRGAN(modules.images.Upscaler):
     def __init__(self, filename, title):
     def __init__(self, filename, title):
         self.name = title
         self.name = title
-        self.model = load_model(filename)
+        self.filename = filename
 
 
     def do_upscale(self, img):
     def do_upscale(self, img):
-        model = self.model.to(shared.device)
+        model = load_model(self.filename, self.name)
+        if model is None:
+            return img
+        model.to(shared.device)
         img = esrgan_upscale(model, img)
         img = esrgan_upscale(model, img)
         return img
         return img
 
 
 
 
-def load_models(dirname):
-    for file in os.listdir(dirname):
-        path = os.path.join(dirname, file)
-        model_name, extension = os.path.splitext(file)
-
-        if extension != '.pt' and extension != '.pth':
-            continue
+def setup_model(dirname):
+    global model_path
+    global model_name
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
 
 
+    model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
+    if len(model_paths) == 0:
+        modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
+    for file in model_paths:
+        name = modelloader.friendly_name(file)
         try:
         try:
-            modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
+            modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
         except Exception:
         except Exception:
-            print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
+            print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)

+ 2 - 0
modules/extras.py

@@ -36,6 +36,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
 
 
     outputs = []
     outputs = []
     for image, image_name in zip(imageArr, imageNameArr):
     for image, image_name in zip(imageArr, imageNameArr):
+        if image is None:
+            return outputs, "Please select an input image.", ''
         existing_pnginfo = image.info or {}
         existing_pnginfo = image.info or {}
 
 
         image = image.convert("RGB")
         image = image.convert("RGB")

+ 29 - 31
modules/gfpgan_model.py

@@ -7,33 +7,20 @@ from modules import shared, devices
 from modules.shared import cmd_opts
 from modules.shared import cmd_opts
 from modules.paths import script_path
 from modules.paths import script_path
 import modules.face_restoration
 import modules.face_restoration
+from modules import shared, devices, modelloader
+from modules.paths import models_path
 
 
-
-def gfpgan_model_path():
-    from modules.shared import cmd_opts
-
-    filemask = 'GFPGAN*.pth'
-
-    if cmd_opts.gfpgan_model is not None:
-        return cmd_opts.gfpgan_model
-
-    places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
-
-    filename = None
-    for place in places:
-        filename = next(iter(glob(os.path.join(place, filemask))), None)
-        if filename is not None:
-            break
-
-    return filename
-
+model_dir = "GFPGAN"
+cmd_dir = None
+model_path = os.path.join(models_path, model_dir)
+model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
 
 
 loaded_gfpgan_model = None
 loaded_gfpgan_model = None
 
 
 
 
 def gfpgan():
 def gfpgan():
     global loaded_gfpgan_model
     global loaded_gfpgan_model
-
+    global model_path
     if loaded_gfpgan_model is not None:
     if loaded_gfpgan_model is not None:
         loaded_gfpgan_model.gfpgan.to(shared.device)
         loaded_gfpgan_model.gfpgan.to(shared.device)
         return loaded_gfpgan_model
         return loaded_gfpgan_model
@@ -41,7 +28,15 @@ def gfpgan():
     if gfpgan_constructor is None:
     if gfpgan_constructor is None:
         return None
         return None
 
 
-    model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+    models = modelloader.load_models(model_path, model_url, cmd_dir)
+    if len(models) != 0:
+        latest_file = max(models, key=os.path.getctime)
+        model_file = latest_file
+    else:
+        print("Unable to load gfpgan model!")
+        return None
+    model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
+                               bg_upsampler=None)
     model.gfpgan.to(shared.device)
     model.gfpgan.to(shared.device)
     loaded_gfpgan_model = model
     loaded_gfpgan_model = model
 
 
@@ -50,7 +45,8 @@ def gfpgan():
 
 
 def gfpgan_fix_faces(np_image):
 def gfpgan_fix_faces(np_image):
     model = gfpgan()
     model = gfpgan()
-
+    if model is None:
+        return np_image
     np_image_bgr = np_image[:, :, ::-1]
     np_image_bgr = np_image[:, :, ::-1]
     cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
     cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
     np_image = gfpgan_output_bgr[:, :, ::-1]
     np_image = gfpgan_output_bgr[:, :, ::-1]
@@ -64,19 +60,21 @@ def gfpgan_fix_faces(np_image):
 have_gfpgan = False
 have_gfpgan = False
 gfpgan_constructor = None
 gfpgan_constructor = None
 
 
-def setup_gfpgan():
-    try:
-        gfpgan_model_path()
 
 
-        if os.path.exists(cmd_opts.gfpgan_dir):
-            sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
-        from gfpgan import GFPGANer
+def setup_model(dirname):
+    global model_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
 
 
+    try:
+        from modules.gfpgan_model_arch import GFPGANerr
+        global cmd_dir
         global have_gfpgan
         global have_gfpgan
-        have_gfpgan = True
-
         global gfpgan_constructor
         global gfpgan_constructor
-        gfpgan_constructor = GFPGANer
+
+        cmd_dir = dirname
+        have_gfpgan = True
+        gfpgan_constructor = GFPGANerr
 
 
         class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
         class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
             def name(self):
             def name(self):

+ 150 - 0
modules/gfpgan_model_arch.py

@@ -0,0 +1,150 @@
+# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
+
+import cv2
+import os
+import torch
+from basicsr.utils import img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from torchvision.transforms.functional import normalize
+
+from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
+from gfpgan.archs.gfpganv1_arch import GFPGANv1
+from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
class GFPGANerr():
    """Helper for restoration with GFPGAN.

    It will detect and crop faces, and then resize the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces will be pasted back to the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be a URL (the file is then downloaded
            into ``model_dir`` first).
        model_dir (str): Directory used for downloaded weights; it is also handed to the face helper
            as its ``model_rootpath``.
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original | RestoreFormer.
            Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
        device (torch.device): Device to run on. Default: None, which selects CUDA when available
            and the CPU otherwise.

    Raises:
        ValueError: If ``arch`` is not one of the supported architecture names.
    """

    def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
        # Reject unknown architectures up front; otherwise self.gfpgan would never be assigned and the
        # failure would only surface later as an AttributeError while loading the weights.
        supported_archs = ('clean', 'bilinear', 'original', 'RestoreFormer')
        if arch not in supported_archs:
            raise ValueError(f"Unsupported GFPGAN arch '{arch}', expected one of: {', '.join(supported_archs)}")

        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        # initialize the GFP-GAN
        if arch == 'RestoreFormer':
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            self.gfpgan = RestoreFormer()
        else:
            # The three remaining architectures take the same constructor arguments and differ only
            # in the network class and in whether the decoder is fixed.
            net_class, fix_decoder = {
                'clean': (GFPGANv1Clean, False),
                'bilinear': (GFPGANBilinear, False),
                'original': (GFPGANv1, True),
            }[arch]
            self.gfpgan = net_class(
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=fix_decoder,
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)

        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device,
            model_rootpath=model_dir)

        if model_path.startswith(('http://', 'https://')):
            model_path = load_file_from_url(
                url=model_path, model_dir=model_dir, progress=True, file_name=None)

        # Load onto the CPU first so that a checkpoint saved from a CUDA device can still be read on a
        # machine without CUDA; the network is moved to self.device below.
        # NOTE(review): torch.load unpickles the file, so only trusted checkpoints should be loaded.
        loadnet = torch.load(model_path, map_location='cpu')
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
        """Restore the faces found in ``img``.

        Args:
            img: Input image array; presumably BGR channel order, since crops are converted with
                ``bgr2rgb=True`` - confirm against callers.
            has_aligned (bool): When True the input is treated as one already-aligned face: it is
                resized to 512x512 and used directly, skipping detection and alignment.
            only_center_face (bool): Forwarded to the face helper's landmark detection.
            paste_back (bool): When True (and ``has_aligned`` is False) the restored faces are pasted
                back into the image.
            weight (float): Forwarded unchanged to the network call.

        Returns:
            A tuple ``(cropped_faces, restored_faces, restored_img)``. ``restored_img`` is None when
            ``has_aligned`` is True or ``paste_back`` is False.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # Best effort: keep the unrestored crop so the remaining faces are still processed.
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None

+ 26 - 19
modules/ldsr_model.py

@@ -3,11 +3,14 @@ import sys
 import traceback
 import traceback
 from collections import namedtuple
 from collections import namedtuple
 
 
-from basicsr.utils.download_util import load_file_from_url
+from modules import shared, images, modelloader, paths
+from modules.paths import models_path
 
 
-import modules.images
-from modules import shared
-from modules.paths import script_path
+model_dir = "LDSR"
+model_path = os.path.join(models_path, model_dir)
+cmd_path = None
+model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
+yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
 
 
 LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
 LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
 
 
@@ -25,28 +28,32 @@ class UpscalerLDSR(modules.images.Upscaler):
         return upscale_with_ldsr(img)
         return upscale_with_ldsr(img)
 
 
 
 
-def add_lsdr():
-    modules.shared.sd_upscalers.append(UpscalerLDSR(100))
+def setup_model(dirname):
+    global cmd_path
+    global model_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+    cmd_path = dirname
+    shared.sd_upscalers.append(UpscalerLDSR(100))
 
 
 
 
-def setup_ldsr():
-    path = modules.paths.paths.get("LDSR", None)
+def prepare_ldsr():
+    path = paths.paths.get("LDSR", None)
     if path is None:
     if path is None:
         return
         return
     global have_ldsr
     global have_ldsr
     global LDSR_obj
     global LDSR_obj
     try:
     try:
         from LDSR import LDSR
         from LDSR import LDSR
-        model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
-        yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
-        repo_path = 'latent-diffusion/experiments/pretrained_models/'
-        model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
-                                        progress=True, file_name="model.chkpt")
-        yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
-                                       progress=True, file_name="project.yaml")
-        have_ldsr = True
-        LDSR_obj = LDSR(model_path, yaml_path)
-
+        model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
+        yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
+        if len(model_files) != 0 and len(yaml_files) != 0:
+            model_file = model_files[0]
+            yaml_file = yaml_files[0]
+            have_ldsr = True
+            LDSR_obj = LDSR(model_file, yaml_file)
+        else:
+            return
 
 
     except Exception:
     except Exception:
         print("Error importing LDSR:", file=sys.stderr)
         print("Error importing LDSR:", file=sys.stderr)
@@ -55,7 +62,7 @@ def setup_ldsr():
 
 
 
 
 def upscale_with_ldsr(image):
 def upscale_with_ldsr(image):
-    setup_ldsr()
+    prepare_ldsr()
     if not have_ldsr or LDSR_obj is None:
     if not have_ldsr or LDSR_obj is None:
         return image
         return image
 
 

+ 65 - 0
modules/modelloader.py

@@ -0,0 +1,65 @@
+import os
+from urllib.parse import urlparse
+
+from basicsr.utils.download_util import load_file_from_url
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None,
                ext_filter=None) -> list:
    """
    A one-and-done loader that tries to find the desired models in the specified directories.

    Directories are searched in this order: ``command_path`` (its ``experiments/pretrained_models``
    sub-directory when that exists, otherwise the directory itself), then ``model_path``. Files in
    each directory are visited in sorted order so the result is deterministic. Errors while scanning
    or downloading are reported on stderr and whatever was found up to that point is returned.

    @param dl_name: The file name to use for downloading a model. If not specified, the downloader picks the name.
    @param model_url: If specified, attempt to download the model from the given URL into model_path.
    @param model_path: The location to store/find models in.
    @param command_path: A command-line supplied directory to search for models in first.
    @param existing: A list of already known model paths; new paths are appended to this same list.
    @param ext_filter: An optional list of filename extensions (including the dot) to filter by.
    @return: A list of paths containing the desired model(s)
    """
    # Imported locally so this function does not depend on the module's import block changing.
    import sys
    import traceback

    if ext_filter is None:
        ext_filter = []
    if existing is None:
        existing = []
    try:
        places = []
        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)
        places.append(model_path)

        for place in places:
            if not os.path.exists(place):
                continue
            for file in sorted(os.listdir(place)):
                path = os.path.join(place, file)
                # Test the full path: the bare name would be resolved against the working directory.
                if os.path.isdir(path):
                    continue
                if ext_filter and os.path.splitext(file)[1] not in ext_filter:
                    continue
                # Compare full paths, since that is what the list holds.
                if path not in existing:
                    existing.append(path)

        if model_url is not None:
            if dl_name is not None:
                model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True)
            else:
                model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True)

            if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing:
                existing.append(model_file)
    except Exception:
        # Stay best-effort (e.g. no network, unreadable directory) but do not hide the reason.
        print(f"Error loading models from {model_path}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    return existing
+
+
def friendly_name(file: str):
    """
    Build a human-readable model name from a file path or download URL.

    For http(s) URLs only the path component is used, so query strings and fragments are ignored.
    The directory part and the extension are then dropped, underscores become spaces and the result
    is title-cased.

    @param file: A local file path or an http(s) URL pointing at a model file.
    @return: The display name derived from the file name.
    """
    parsed = urlparse(file)
    # Decide by scheme rather than by substring: a local path that merely contains "http" must not
    # be URL-parsed, otherwise '#' or '?' in the file name would cut the name short.
    if parsed.scheme in ("http", "https"):
        file = parsed.path

    file = os.path.basename(file)
    model_name, extension = os.path.splitext(file)
    model_name = model_name.replace("_", " ").title()
    return model_name

+ 2 - 1
modules/paths.py

@@ -3,9 +3,10 @@ import os
 import sys
 import sys
 
 
 script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+models_path = os.path.join(script_path, "models")
 sys.path.insert(0, script_path)
 sys.path.insert(0, script_path)
 
 
-# search for directory of stable diffsuion in following palces
+# search for directory of stable diffusion in following places
 sd_path = None
 sd_path = None
 possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
 possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
 for possible_sd_path in possible_sd_paths:
 for possible_sd_path in possible_sd_paths:

+ 19 - 4
modules/realesrgan_model.py

@@ -1,14 +1,20 @@
+import os
 import sys
 import sys
 import traceback
 import traceback
 from collections import namedtuple
 from collections import namedtuple
 
 
 import numpy as np
 import numpy as np
 from PIL import Image
 from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
 from realesrgan import RealESRGANer
 from realesrgan import RealESRGANer
 
 
 import modules.images
 import modules.images
+from modules.paths import models_path
 from modules.shared import cmd_opts, opts
 from modules.shared import cmd_opts, opts
 
 
+model_dir = "RealESRGAN"
+model_path = os.path.join(models_path, model_dir)
+cmd_dir = None
 RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
 RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
 realesrgan_models = []
 realesrgan_models = []
 have_realesrgan = False
 have_realesrgan = False
@@ -17,7 +23,6 @@ have_realesrgan = False
 def get_realesrgan_models():
 def get_realesrgan_models():
     try:
     try:
         from basicsr.archs.rrdbnet_arch import RRDBNet
         from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan import RealESRGANer
         from realesrgan.archs.srvgg_arch import SRVGGNetCompact
         from realesrgan.archs.srvgg_arch import SRVGGNetCompact
         models = [
         models = [
             RealesrganModelInfo(
             RealesrganModelInfo(
@@ -59,7 +64,7 @@ def get_realesrgan_models():
         ]
         ]
         return models
         return models
     except Exception as e:
     except Exception as e:
-        print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
+        print("Error making Real-ESRGAN models list:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
 
 
 
 
@@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler):
         return upscale_with_realesrgan(img, self.upscaling, self.model_index)
         return upscale_with_realesrgan(img, self.upscaling, self.model_index)
 
 
 
 
-def setup_realesrgan():
+def setup_model(dirname):
+    global model_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+
     global realesrgan_models
     global realesrgan_models
     global have_realesrgan
     global have_realesrgan
-
+    if model_path != dirname:
+        model_path = dirname
     try:
     try:
         from basicsr.archs.rrdbnet_arch import RRDBNet
         from basicsr.archs.rrdbnet_arch import RRDBNet
         from realesrgan import RealESRGANer
         from realesrgan import RealESRGANer
@@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
     info = realesrgan_models[RealESRGAN_model_index]
     info = realesrgan_models[RealESRGAN_model_index]
 
 
     model = info.model()
     model = info.model()
+    model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
+    if not os.path.exists(model_file):
+        print("Unable to load RealESRGAN model: %s" % info.name)
+        return image
+
     upsampler = RealESRGANer(
     upsampler = RealESRGANer(
         scale=info.netscale,
         scale=info.netscale,
         model_path=info.location,
         model_path=info.location,

+ 8 - 4
modules/shared.py

@@ -16,11 +16,11 @@ import modules.sd_models
 
 
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
 default_sd_model_file = sd_model_file
-
+model_path = os.path.join(script_path, 'models')
 parser = argparse.ArgumentParser()
 parser = argparse.ArgumentParser()
 parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
 parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
-parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
+parser.add_argument("--ckpt-dir", type=str, default=model_path, help="path to directory with stable diffusion checkpoints",)
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
 parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
 parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
 parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
 parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -34,8 +34,12 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
-parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
-parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
+parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model(s)", default=os.path.join(model_path, 'Codeformer'))
+parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model(s)", default=os.path.join(model_path, 'GFPGAN'))
+parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN models", default=os.path.join(model_path, 'ESRGAN'))
+parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN models", default=os.path.join(model_path, 'RealESRGAN'))
+parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR models", default=os.path.join(model_path, 'SwinIR'))
+parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR models", default=os.path.join(model_path, 'LDSR'))
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")

+ 55 - 20
modules/swinir_model.py

@@ -1,21 +1,39 @@
+import contextlib
+import os
 import sys
 import sys
 import traceback
 import traceback
-import cv2
-import os
-import contextlib
+
 import numpy as np
 import numpy as np
-from PIL import Image
 import torch
 import torch
+from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
+
 import modules.images
 import modules.images
+from modules import modelloader
+from modules.paths import models_path
 from modules.shared import cmd_opts, opts, device
 from modules.shared import cmd_opts, opts, device
-from modules.swinir_arch import SwinIR as net
+from modules.swinir_model_arch import SwinIR as net
 
 
+model_dir = "SwinIR"
+model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
+model_name = "SwinIR x4"
+model_path = os.path.join(models_path, model_dir)
+cmd_path = ""
 precision_scope = (
 precision_scope = (
     torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
     torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
 )
 )
 
 
 
 
-def load_model(filename, scale=4):
+def load_model(path, scale=4):
+    global model_path
+    global model_name
+    if "http" in path:
+        dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
+        filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
+    else:
+        filename = path
+    if filename is None or not os.path.exists(filename):
+        return None
     model = net(
     model = net(
         upscale=scale,
         upscale=scale,
         in_chans=3,
         in_chans=3,
@@ -37,19 +55,29 @@ def load_model(filename, scale=4):
     return model
     return model
 
 
 
 
-def load_models(dirname):
-    for file in os.listdir(dirname):
-        path = os.path.join(dirname, file)
-        model_name, extension = os.path.splitext(file)
+def setup_model(dirname):
+    global model_path
+    global model_name
+    global cmd_path
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+    cmd_path = dirname
+    model_file = ""
+    try:
+        models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
 
 
-        if extension != ".pt" and extension != ".pth":
-            continue
+        if len(models) != 0:
+            model_file = models[0]
+            name = modelloader.friendly_name(model_file)
+        else:
+            # Add the "default" model if none are found.
+            model_file = model_url
+            name = model_name
 
 
-        try:
-            modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
-        except Exception:
-            print(f"Error loading SwinIR model: {path}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
+        modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
+    except Exception:
+        print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
 
 
 
 
 def upscale(
 def upscale(
@@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 class UpscalerSwin(modules.images.Upscaler):
 class UpscalerSwin(modules.images.Upscaler):
     def __init__(self, filename, title):
     def __init__(self, filename, title):
         self.name = title
         self.name = title
-        self.model = load_model(filename)
+        self.filename = filename
 
 
     def do_upscale(self, img):
     def do_upscale(self, img):
-        model = self.model.to(device)
+        model = load_model(self.filename)
+        if model is None:
+            return img
+        model = model.to(device)
         img = upscale(img, model)
         img = upscale(img, model)
-        return img
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+        return img

+ 21 - 24
webui.py

@@ -1,37 +1,34 @@
 import os
 import os
-import threading
-
-from modules.paths import script_path
-
 import signal
 import signal
+import threading
 
 
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-import modules.ui
-import modules.scripts
-import modules.sd_hijack
-import modules.codeformer_model
-import modules.gfpgan_model
-import modules.face_restoration
-import modules.realesrgan_model as realesrgan
+import modules.codeformer_model as codeformer
 import modules.esrgan_model as esrgan
 import modules.esrgan_model as esrgan
-import modules.ldsr_model as ldsr
 import modules.extras
 import modules.extras
-import modules.lowvram
-import modules.txt2img
+import modules.face_restoration
+import modules.gfpgan_model as gfpgan
 import modules.img2img
 import modules.img2img
-import modules.swinir as swinir
+import modules.ldsr_model as ldsr
+import modules.lowvram
+import modules.realesrgan_model as realesrgan
+import modules.scripts
+import modules.sd_hijack
 import modules.sd_models
 import modules.sd_models
+import modules.shared as shared
+import modules.swinir_model as swinir
+import modules.txt2img
+import modules.ui
+from modules.paths import script_path
+from modules.shared import cmd_opts
 
 
-
-modules.codeformer_model.setup_codeformer()
-modules.gfpgan_model.setup_gfpgan()
+codeformer.setup_model(cmd_opts.codeformer_models_path)
+gfpgan.setup_model(cmd_opts.gfpgan_models_path)
 shared.face_restorers.append(modules.face_restoration.FaceRestoration())
 shared.face_restorers.append(modules.face_restoration.FaceRestoration())
 
 
-esrgan.load_models(cmd_opts.esrgan_models_path)
-swinir.load_models(cmd_opts.swinir_models_path)
-realesrgan.setup_realesrgan()
-ldsr.add_lsdr()
+esrgan.setup_model(cmd_opts.esrgan_models_path)
+swinir.setup_model(cmd_opts.swinir_models_path)
+realesrgan.setup_model(cmd_opts.realesrgan_models_path)
+ldsr.setup_model(cmd_opts.ldsr_models_path)
 queue_lock = threading.Lock()
 queue_lock = threading.Lock()