Browse Source

Automatically enable ztSNR based on existence of key in state_dict

catboxanon 10 months ago
parent
commit
c2ce1d3b9c
1 changed files with 5 additions and 1 deletions
  1. 5 1
      modules/sd_models.py

+ 5 - 1
modules/sd_models.py

@@ -423,6 +423,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     set_model_type(model, state_dict)
     set_model_fields(model)
+    if 'ztsnr' in state_dict:
+        model.ztsnr = True
+    else:
+        model.ztsnr = False
 
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
@@ -661,7 +665,7 @@ def apply_alpha_schedule_override(sd_model, p=None):
             p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
         sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
 
-    if opts.sd_noise_schedule == "Zero Terminal SNR":
+    if opts.sd_noise_schedule == "Zero Terminal SNR" or (hasattr(sd_model, 'ztsnr') and sd_model.ztsnr):
         if p is not None:
             p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
         sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)