Browse Source

use_checkpoint = False

huchenlei 1 year ago
parent
commit
022d835565

+ 1 - 1
configs/alt-diffusion-inference.yaml

@@ -40,7 +40,7 @@ model:
         use_spatial_transformer: True
         use_spatial_transformer: True
         transformer_depth: 1
         transformer_depth: 1
         context_dim: 768
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
         legacy: False
 
 
     first_stage_config:
     first_stage_config:

+ 1 - 1
configs/alt-diffusion-m18-inference.yaml

@@ -41,7 +41,7 @@ model:
         use_linear_in_transformer: True
         use_linear_in_transformer: True
         transformer_depth: 1
         transformer_depth: 1
         context_dim: 1024
         context_dim: 1024
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
         legacy: False
 
 
     first_stage_config:
     first_stage_config:

+ 1 - 1
configs/instruct-pix2pix.yaml

@@ -45,7 +45,7 @@ model:
         use_spatial_transformer: True
         use_spatial_transformer: True
         transformer_depth: 1
         transformer_depth: 1
         context_dim: 768
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
         legacy: False
 
 
     first_stage_config:
     first_stage_config:

+ 1 - 1
configs/sd_xl_inpaint.yaml

@@ -21,7 +21,7 @@ model:
       params:
       params:
         adm_in_channels: 2816
         adm_in_channels: 2816
         num_classes: sequential
         num_classes: sequential
-        use_checkpoint: True
+        use_checkpoint: False
         in_channels: 9
         in_channels: 9
         out_channels: 4
         out_channels: 4
         model_channels: 320
         model_channels: 320

+ 1 - 1
configs/v1-inference.yaml

@@ -40,7 +40,7 @@ model:
         use_spatial_transformer: True
         use_spatial_transformer: True
         transformer_depth: 1
         transformer_depth: 1
         context_dim: 768
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
         legacy: False
 
 
     first_stage_config:
     first_stage_config:

+ 1 - 1
configs/v1-inpainting-inference.yaml

@@ -40,7 +40,7 @@ model:
         use_spatial_transformer: True
         use_spatial_transformer: True
         transformer_depth: 1
         transformer_depth: 1
         context_dim: 768
         context_dim: 768
-        use_checkpoint: True
+        use_checkpoint: False
         legacy: False
         legacy: False
 
 
     first_stage_config:
     first_stage_config:

+ 6 - 3
modules/sd_hijack_checkpoint.py

@@ -4,16 +4,19 @@ import ldm.modules.attention
 import ldm.modules.diffusionmodules.openaimodel
 import ldm.modules.diffusionmodules.openaimodel
 
 
 
 
+# Setting flag=False so that torch skips checking parameters.
+# parameters checking is expensive in frequent operations.
+
 def BasicTransformerBlock_forward(self, x, context=None):
 def BasicTransformerBlock_forward(self, x, context=None):
-    return checkpoint(self._forward, x, context)
+    return checkpoint(self._forward, x, context, flag=False)
 
 
 
 
 def AttentionBlock_forward(self, x):
 def AttentionBlock_forward(self, x):
-    return checkpoint(self._forward, x)
+    return checkpoint(self._forward, x, flag=False)
 
 
 
 
 def ResBlock_forward(self, x, emb):
 def ResBlock_forward(self, x, emb):
-    return checkpoint(self._forward, x, emb)
+    return checkpoint(self._forward, x, emb, flag=False)
 
 
 
 
 stored = []
 stored = []

+ 1 - 1
modules/sd_models_config.py

@@ -35,7 +35,7 @@ def is_using_v_parameterization_for_sd2(state_dict):
 
 
     with sd_disable_initialization.DisableInitialization():
     with sd_disable_initialization.DisableInitialization():
         unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
         unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
-            use_checkpoint=True,
+            use_checkpoint=False,
             use_fp16=False,
             use_fp16=False,
             image_size=32,
             image_size=32,
             in_channels=4,
             in_channels=4,