瀏覽代碼

repair --medvram for SD2.x too after SDXL update

AUTOMATIC1111 2 年之前
父節點
當前提交
2c11e9009e
共有 2 個文件被更改,包括 5 次插入和 4 次刪除
  1. 4 3
      modules/lowvram.py
  2. 1 1
      modules/sd_hijack_open_clip.py

+ 4 - 3
modules/lowvram.py

@@ -90,8 +90,12 @@ def setup_for_low_vram(sd_model, use_medvram):
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
     elif is_sd2:
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
+        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
+        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
     else:
         sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
@@ -101,9 +105,6 @@ def setup_for_low_vram(sd_model, use_medvram):
     if sd_model.embedder:
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
 
-    if hasattr(sd_model, 'cond_stage_model'):
-        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
-
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
     else:

+ 1 - 1
modules/sd_hijack_open_clip.py

@@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
     def encode_embedding_init_text(self, init_text, nvpt):
         ids = tokenizer.encode(init_text)
         ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
-        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
 
         return embedded