Преглед на файлове

SD2 v autodetection fix

AUTOMATIC1111 преди 1 година
родител
ревизия
74069addc3
променени са 1 файла, в които са добавени 3 реда и са изтрити 2 реда
  1. 3 2
      modules/sd_models_config.py

+ 3 - 2
modules/sd_models_config.py

@@ -58,12 +58,13 @@ def is_using_v_parameterization_for_sd2(state_dict):
     with torch.no_grad():
         unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
         unet.load_state_dict(unet_sd, strict=True)
-        unet.to(device=device, dtype=torch.float)
+        unet.to(device=device, dtype=devices.dtype_unet)
 
         test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
         x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
 
-        out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
+        with devices.autocast():
+            out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
 
     return out < -1