|
@@ -58,12 +58,13 @@ def is_using_v_parameterization_for_sd2(state_dict):
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
|
unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
|
|
unet.load_state_dict(unet_sd, strict=True)
|
|
unet.load_state_dict(unet_sd, strict=True)
|
|
- unet.to(device=device, dtype=torch.float)
|
|
|
|
|
|
+ unet.to(device=device, dtype=devices.dtype_unet)
|
|
|
|
|
|
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
|
test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
|
|
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
|
x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
|
|
|
|
|
|
- out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
|
|
|
|
|
|
+ with devices.autocast():
|
|
|
|
+ out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()
|
|
|
|
|
|
return out < -1
|
|
return out < -1
|
|
|
|
|