Browse Source

send all three of GFPGAN's and codeformer's models to CPU memory instead of just one for #1283

AUTOMATIC 2 years ago
parent
commit
6c6ae28bf5
4 changed files with 41 additions and 11 deletions
  1. 10 2
      modules/codeformer_model.py
  2. 10 0
      modules/devices.py
  3. 12 2
      modules/gfpgan_model.py
  4. 9 7
      modules/processing.py

+ 10 - 2
modules/codeformer_model.py

@@ -69,10 +69,14 @@ def setup_model(dirname):
 
 
                 self.net = net
                 self.net = net
                 self.face_helper = face_helper
                 self.face_helper = face_helper
-                self.net.to(devices.device_codeformer)
 
 
                 return net, face_helper
                 return net, face_helper
 
 
+            def send_model_to(self, device):
+                self.net.to(device)
+                self.face_helper.face_det.to(device)
+                self.face_helper.face_parse.to(device)
+
             def restore(self, np_image, w=None):
             def restore(self, np_image, w=None):
                 np_image = np_image[:, :, ::-1]
                 np_image = np_image[:, :, ::-1]
 
 
@@ -82,6 +86,8 @@ def setup_model(dirname):
                 if self.net is None or self.face_helper is None:
                 if self.net is None or self.face_helper is None:
                     return np_image
                     return np_image
 
 
+                self.send_model_to(devices.device_codeformer)
+
                 self.face_helper.clean_all()
                 self.face_helper.clean_all()
                 self.face_helper.read_image(np_image)
                 self.face_helper.read_image(np_image)
                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -113,8 +119,10 @@ def setup_model(dirname):
                 if original_resolution != restored_img.shape[0:2]:
                 if original_resolution != restored_img.shape[0:2]:
                     restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
                     restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
 
 
+                self.face_helper.clean_all()
+
                 if shared.opts.face_restoration_unload:
                 if shared.opts.face_restoration_unload:
-                    self.net.to(devices.cpu)
+                    self.send_model_to(devices.cpu)
 
 
                 return restored_img
                 return restored_img
 
 

+ 10 - 0
modules/devices.py

@@ -1,3 +1,5 @@
+import contextlib
+
 import torch
 import torch
 
 
 # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
 # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -57,3 +59,11 @@ def randn_without_seed(shape):
 
 
     return torch.randn(shape, device=device)
     return torch.randn(shape, device=device)
 
 
+
+def autocast():
+    from modules import shared
+
+    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+        return contextlib.nullcontext()
+
+    return torch.autocast("cuda")

+ 12 - 2
modules/gfpgan_model.py

@@ -37,22 +37,32 @@ def gfpgann():
         print("Unable to load gfpgan model!")
         print("Unable to load gfpgan model!")
         return None
         return None
     model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
     model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
-    model.gfpgan.to(shared.device)
     loaded_gfpgan_model = model
     loaded_gfpgan_model = model
 
 
     return model
     return model
 
 
 
 
+def send_model_to(model, device):
+    model.gfpgan.to(device)
+    model.face_helper.face_det.to(device)
+    model.face_helper.face_parse.to(device)
+
+
 def gfpgan_fix_faces(np_image):
 def gfpgan_fix_faces(np_image):
     model = gfpgann()
     model = gfpgann()
     if model is None:
     if model is None:
         return np_image
         return np_image
+
+    send_model_to(model, devices.device)
+
     np_image_bgr = np_image[:, :, ::-1]
     np_image_bgr = np_image[:, :, ::-1]
     cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
     cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
     np_image = gfpgan_output_bgr[:, :, ::-1]
     np_image = gfpgan_output_bgr[:, :, ::-1]
 
 
+    model.face_helper.clean_all()
+
     if shared.opts.face_restoration_unload:
     if shared.opts.face_restoration_unload:
-        model.gfpgan.to(devices.cpu)
+        send_model_to(model, devices.cpu)
 
 
     return np_image
     return np_image
 
 

+ 9 - 7
modules/processing.py

@@ -1,4 +1,3 @@
-import contextlib
 import json
 import json
 import math
 import math
 import os
 import os
@@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
 
     infotexts = []
     infotexts = []
     output_images = []
     output_images = []
-    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
-    with torch.no_grad(), precision_scope("cuda"), ema_scope():
+
+    with torch.no_grad():
         p.init(all_prompts, all_seeds, all_subseeds)
         p.init(all_prompts, all_seeds, all_subseeds)
 
 
         if state.job_count == -1:
         if state.job_count == -1:
@@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
 
             #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
             #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
             #c = p.sd_model.get_learned_conditioning(prompts)
             #c = p.sd_model.get_learned_conditioning(prompts)
-            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
-            c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+            with devices.autocast():
+                uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
+                c = prompt_parser.get_learned_conditioning(prompts, p.steps)
 
 
             if len(model_hijack.comments) > 0:
             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
                 for comment in model_hijack.comments:
@@ -361,7 +360,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
 
-            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+            with devices.autocast():
+                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype)
+
             if state.interrupted:
             if state.interrupted:
 
 
                 # if we are interruped, sample returns just noise
                 # if we are interruped, sample returns just noise
@@ -386,6 +387,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                     devices.torch_gc()
                     devices.torch_gc()
 
 
                     x_sample = modules.face_restoration.restore_faces(x_sample)
                     x_sample = modules.face_restoration.restore_faces(x_sample)
+                    devices.torch_gc()
 
 
                 image = Image.fromarray(x_sample)
                 image = Image.fromarray(x_sample)