|
@@ -1,9 +1,12 @@
|
|
|
+from collections import namedtuple
|
|
|
+
|
|
|
import torch
|
|
|
from modules import devices, shared
|
|
|
|
|
|
module_in_gpu = None
|
|
|
cpu = torch.device("cpu")
|
|
|
|
|
|
+ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
|
|
|
|
|
|
def send_everything_to_cpu():
|
|
|
global module_in_gpu
|
|
@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|
|
(sd_model, 'depth_model'),
|
|
|
(sd_model, 'embedder'),
|
|
|
(sd_model, 'model'),
|
|
|
- (sd_model, 'embedder'),
|
|
|
]
|
|
|
|
|
|
is_sdxl = hasattr(sd_model, 'conditioner')
|
|
|
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
|
|
|
|
|
- if is_sdxl:
|
|
|
+ if hasattr(sd_model, 'medvram_fields'):
|
|
|
+ to_remain_in_cpu = sd_model.medvram_fields()
|
|
|
+ elif is_sdxl:
|
|
|
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
|
|
elif is_sd2:
|
|
|
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
|
@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|
|
setattr(obj, field, module)
|
|
|
|
|
|
# register hooks for those the first three models
|
|
|
- if is_sdxl:
|
|
|
+ if hasattr(sd_model.cond_stage_model, "medvram_modules"):
|
|
|
+ for module in sd_model.cond_stage_model.medvram_modules():
|
|
|
+ if isinstance(module, ModuleWithParent):
|
|
|
+ parent = module.parent
|
|
|
+ module = module.module
|
|
|
+ else:
|
|
|
+ parent = None
|
|
|
+
|
|
|
+ if module:
|
|
|
+ module.register_forward_pre_hook(send_me_to_gpu)
|
|
|
+
|
|
|
+ if parent:
|
|
|
+ parents[module] = parent
|
|
|
+
|
|
|
+ elif is_sdxl:
|
|
|
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
|
|
elif is_sd2:
|
|
|
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
|
@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
|
|
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
|
|
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
|
|
- if sd_model.depth_model:
|
|
|
+ if hasattr(sd_model, 'depth_model'):
|
|
|
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
|
|
- if sd_model.embedder:
|
|
|
+ if hasattr(sd_model, 'embedder'):
|
|
|
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
|
|
|
|
|
if use_medvram:
|