Browse Source

medvram support for SD3

AUTOMATIC1111 1 năm trước cách đây
mục cha
commit
a8fba9af35

+ 23 - 5
modules/lowvram.py

@@ -1,9 +1,12 @@
+from collections import namedtuple
+
 import torch
 import torch
 from modules import devices, shared
 from modules import devices, shared
 
 
 module_in_gpu = None
 module_in_gpu = None
 cpu = torch.device("cpu")
 cpu = torch.device("cpu")
 
 
+ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
 
 
 def send_everything_to_cpu():
 def send_everything_to_cpu():
     global module_in_gpu
     global module_in_gpu
@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
         (sd_model, 'depth_model'),
         (sd_model, 'depth_model'),
         (sd_model, 'embedder'),
         (sd_model, 'embedder'),
         (sd_model, 'model'),
         (sd_model, 'model'),
-        (sd_model, 'embedder'),
     ]
     ]
 
 
     is_sdxl = hasattr(sd_model, 'conditioner')
     is_sdxl = hasattr(sd_model, 'conditioner')
     is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
     is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
 
 
-    if is_sdxl:
+    if hasattr(sd_model, 'medvram_fields'):
+        to_remain_in_cpu = sd_model.medvram_fields()
+    elif is_sdxl:
         to_remain_in_cpu.append((sd_model, 'conditioner'))
         to_remain_in_cpu.append((sd_model, 'conditioner'))
     elif is_sd2:
     elif is_sd2:
         to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
         to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
         setattr(obj, field, module)
         setattr(obj, field, module)
 
 
     # register hooks for those the first three models
     # register hooks for those the first three models
-    if is_sdxl:
+    if hasattr(sd_model.cond_stage_model, "medvram_modules"):
+        for module in sd_model.cond_stage_model.medvram_modules():
+            if isinstance(module, ModuleWithParent):
+                parent = module.parent
+                module = module.module
+            else:
+                parent = None
+
+            if module:
+                module.register_forward_pre_hook(send_me_to_gpu)
+
+                if parent:
+                    parents[module] = parent
+
+    elif is_sdxl:
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
     elif is_sd2:
     elif is_sd2:
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
-    if sd_model.depth_model:
+    if hasattr(sd_model, 'depth_model'):
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
-    if sd_model.embedder:
+    if hasattr(sd_model, 'embedder'):
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
 
 
     if use_medvram:
     if use_medvram:

+ 0 - 1
modules/models/sd3/mmdit.py

@@ -492,7 +492,6 @@ class MMDiT(nn.Module):
         device = None,
         device = None,
     ):
     ):
         super().__init__()
         super().__init__()
-        print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
         self.dtype = dtype
         self.dtype = dtype
         self.learn_sigma = learn_sigma
         self.learn_sigma = learn_sigma
         self.in_channels = in_channels
         self.in_channels = in_channels

+ 11 - 1
modules/models/sd3/sd3_model.py

@@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
     def encode_embedding_init_text(self, init_text, nvpt):
     def encode_embedding_init_text(self, init_text, nvpt):
         return torch.tensor([[0]], device=devices.device) # XXX
         return torch.tensor([[0]], device=devices.device) # XXX
 
 
+    def medvram_modules(self):
+        return [self.clip_g, self.clip_l, self.t5xxl]
+
 
 
 class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
 class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
     def __init__(self, inner_model, sigmas):
     def __init__(self, inner_model, sigmas):
@@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
             return self.cond_stage_model(batch)
             return self.cond_stage_model(batch)
 
 
     def apply_model(self, x, t, cond):
     def apply_model(self, x, t, cond):
-        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+        return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
 
 
     def decode_first_stage(self, latent):
     def decode_first_stage(self, latent):
         latent = self.latent_format.process_out(latent)
         latent = self.latent_format.process_out(latent)
@@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
 
 
     def create_denoiser(self):
     def create_denoiser(self):
         return SD3Denoiser(self, self.model.model_sampling.sigmas)
         return SD3Denoiser(self, self.model.model_sampling.sigmas)
+
+    def medvram_fields(self):
+        return [
+            (self, 'first_stage_model'),
+            (self, 'cond_stage_model'),
+            (self, 'model'),
+        ]

+ 1 - 1
modules/sd_samplers_common.py

@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
     else:
     else:
         # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
         # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
         try:
         try:
-            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
+            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
         except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
         except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
             timestep = torch.max(sigma).to(dtype=int)
             timestep = torch.max(sigma).to(dtype=int)
         completed_ratio = (999 - timestep) / 1000
         completed_ratio = (999 - timestep) / 1000