Prechádzať zdrojové kódy

Execute model_loaded_callback after moving to target device

Nuullll 1 rok pred
rodič
commit
a183de04e3
2 zmenil súbory, kde vykonal 5 pridanie a 4 odobranie
  1. 3 3
      modules/sd_models.py
  2. 2 1
      modules/sd_vae.py

+ 3 - 3
modules/sd_models.py

@@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
         sd_hijack.model_hijack.hijack(sd_model)
         sd_hijack.model_hijack.hijack(sd_model)
         timer.record("hijack")
         timer.record("hijack")
 
 
-        script_callbacks.model_loaded_callback(sd_model)
-        timer.record("script callbacks")
-
         if not sd_model.lowvram:
         if not sd_model.lowvram:
             sd_model.to(devices.device)
             sd_model.to(devices.device)
             timer.record("move model to device")
             timer.record("move model to device")
 
 
+        script_callbacks.model_loaded_callback(sd_model)
+        timer.record("script callbacks")
+
     print(f"Weights loaded in {timer.summary()}.")
     print(f"Weights loaded in {timer.summary()}.")
 
 
     model_data.set_sd_model(sd_model)
     model_data.set_sd_model(sd_model)

+ 2 - 1
modules/sd_vae.py

@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     load_vae(sd_model, vae_file, vae_source)
     load_vae(sd_model, vae_file, vae_source)
 
 
     sd_hijack.model_hijack.hijack(sd_model)
     sd_hijack.model_hijack.hijack(sd_model)
-    script_callbacks.model_loaded_callback(sd_model)
 
 
     if not sd_model.lowvram:
     if not sd_model.lowvram:
         sd_model.to(devices.device)
         sd_model.to(devices.device)
 
 
+    script_callbacks.model_loaded_callback(sd_model)
+
     print("VAE weights loaded.")
     print("VAE weights loaded.")
     return sd_model
     return sd_model