Pārlūkot izejas kodu

Add --precision half cmd option

huchenlei 1 gadu atpakaļ
vecāks
revīzija
2a8a60c2c5

+ 1 - 1
modules/cmd_args.py

@@ -41,7 +41,7 @@ parser.add_argument("--lowvram", action='store_true', help="enable stable diffus
 parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
 parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
-parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
+parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
 parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)

+ 24 - 0
modules/devices.py

@@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
 fp8: bool = False
+# Force fp16 for all models in inference. No casting during inference.
+# This flag is controlled by "--precision half" command line arg.
+force_fp16: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -127,6 +130,8 @@ unet_needs_upcast = False
 
 
 def cond_cast_unet(input):
+    if force_fp16:
+        return input.to(torch.float16)
     return input.to(dtype_unet) if unet_needs_upcast else input
 
 
@@ -206,6 +211,11 @@ def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
 
+    if force_fp16:
+        # No casting during inference if force_fp16 is enabled.
+        # All tensor dtype conversion happens before inference.
+        return contextlib.nullcontext()
+
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
@@ -269,3 +279,17 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
+
+
+def force_model_fp16():
+    """
+    ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which
+    force conversion of input to float32. If force_fp16 is enabled, we need to
+    prevent this casting.
+    """
+    assert force_fp16
+    import sgm.modules.diffusionmodules.util as sgm_util
+    import ldm.modules.diffusionmodules.util as ldm_util
+    sgm_util.GroupNorm32 = torch.nn.GroupNorm
+    ldm_util.GroupNorm32 = torch.nn.GroupNorm
+    print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")

+ 22 - 7
modules/sd_hijack_unet.py

@@ -36,7 +36,7 @@ th = TorchHijackForUnet()
 
 # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
 def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
-
+    """Always make sure inputs to unet are in correct dtype."""
     if isinstance(cond, dict):
         for y in cond.keys():
             if isinstance(cond[y], list):
@@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
                 cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
 
     with devices.autocast():
-        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
+        if devices.unet_needs_upcast:
+            return result.float()
+        else:
+            return result
 
 
 class GELUHijack(torch.nn.GELU, torch.nn.Module):
@@ -64,12 +68,11 @@ def hijack_ddpm_edit():
     if not ddpm_edit_hijack:
         CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
         CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
-        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
 
 
 unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
+
 if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
     CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
     CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
@@ -81,5 +84,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
 
-CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
-CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
+
+
+def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
+    if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
+        dtype = torch.float32
+    else:
+        dtype = devices.dtype_unet
+    return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
+
+
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)

+ 15 - 11
modules/sd_hijack_utils.py

@@ -1,7 +1,11 @@
 import importlib
 
+
+always_true_func = lambda *args, **kwargs: True
+
+
 class CondFunc:
-    def __new__(cls, orig_func, sub_func, cond_func):
+    def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
         self = super(CondFunc, cls).__new__(cls)
         if isinstance(orig_func, str):
             func_path = orig_func.split('.')
@@ -20,13 +24,13 @@ class CondFunc:
                 print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
                 pass
         self.__init__(orig_func, sub_func, cond_func)
-        return lambda *args, **kwargs: self(*args, **kwargs)
-    def __init__(self, orig_func, sub_func, cond_func):
-        self.__orig_func = orig_func
-        self.__sub_func = sub_func
-        self.__cond_func = cond_func
-    def __call__(self, *args, **kwargs):
-        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
-            return self.__sub_func(self.__orig_func, *args, **kwargs)
-        else:
-            return self.__orig_func(*args, **kwargs)
+        return lambda *args, **kwargs: self(*args, **kwargs)
+    def __init__(self, orig_func, sub_func, cond_func):
+        self.__orig_func = orig_func
+        self.__sub_func = sub_func
+        self.__cond_func = cond_func
+    def __call__(self, *args, **kwargs):
+        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+            return self.__sub_func(self.__orig_func, *args, **kwargs)
+        else:
+            return self.__orig_func(*args, **kwargs)

+ 1 - 0
modules/sd_models.py

@@ -403,6 +403,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         model.float()
         model.alphas_cumprod_original = model.alphas_cumprod
         devices.dtype_unet = torch.float32
+        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
         timer.record("apply float()")
     else:
         vae = model.first_stage_model

+ 8 - 0
modules/shared_init.py

@@ -31,6 +31,14 @@ def initialize():
     devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
     devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
 
+    if cmd_opts.precision == "half":
+        msg = "--no-half and --no-half-vae conflict with --precision half"
+        assert devices.dtype == torch.float16, msg
+        assert devices.dtype_vae == torch.float16, msg
+        assert devices.dtype_inference == torch.float16, msg
+        devices.force_fp16 = True
+        devices.force_model_fp16()
+
     shared.device = devices.device
     shared.weight_load_location = None if cmd_opts.lowram else "cpu"