|
@@ -90,8 +90,12 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|
|
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
|
|
elif is_sd2:
|
|
|
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
|
|
+ sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
|
|
|
+ parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
|
|
|
+ parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
|
|
|
else:
|
|
|
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
|
|
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
|
|
|
|
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
|
|
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
|
@@ -101,9 +105,6 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|
|
if sd_model.embedder:
|
|
|
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
|
|
|
|
|
- if hasattr(sd_model, 'cond_stage_model'):
|
|
|
- parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
|
|
-
|
|
|
if use_medvram:
|
|
|
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
|
|
else:
|