Browse Source

gfpgan: just download the damn model

AUTOMATIC 2 years ago
parent
commit
d4205e66fa
2 changed files with 14 additions and 8 deletions
  1. 13 6
      modules/gfpgan_model.py
  2. 1 2
      modules/shared.py

+ 13 - 6
modules/gfpgan_model.py

@@ -1,6 +1,7 @@
 import os
 import sys
 import traceback
+from glob import glob
 
 from modules import shared, devices
 from modules.shared import cmd_opts
@@ -11,14 +12,20 @@ import modules.face_restoration
 def gfpgan_model_path():
     from modules.shared import cmd_opts
 
+    filemask = 'GFPGAN*.pth'
+
+    if cmd_opts.gfpgan_model is not None:
+        return cmd_opts.gfpgan_model
+
     places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
-    files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
-    found = [x for x in files if os.path.exists(x)]
 
-    if len(found) == 0:
-        raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
+    filename = None
+    for place in places:
+        filename = next(iter(glob(os.path.join(place, filemask))), None)
+        if filename is not None:
+            break
 
-    return found[0]
+    return filename
 
 
 loaded_gfpgan_model = None
@@ -34,7 +41,7 @@ def gfpgan():
     if gfpgan_constructor is None:
         return None
 
-    model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+    model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
     model.gfpgan.to(shared.device)
     loaded_gfpgan_model = model
 

+ 1 - 2
modules/shared.py

@@ -2,7 +2,6 @@ import sys
 import argparse
 import json
 import os
-from glob import glob
 import gradio as gr
 import tqdm
 
@@ -22,7 +21,7 @@ parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
 parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
 parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
-parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=next(iter(glob('GFPGAN*.pth')), ''))
+parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
 parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")