@@ -339,10 +339,29 @@ class SkipWritingToConfig:
         SkipWritingToConfig.skip = self.previous


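+# Decide whether fp8 weight storage should be enabled for the given model; returns None if there is no model.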
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")

+    if not check_fp8(model) and devices.fp8:
+        # prevent the model from loading the state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

@@ -395,34 +414,17 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
         devices.dtype_unet = torch.float16
         timer.record("apply half()")

-    if devices.get_optimal_device_name() == "mps":
-        enable_fp8 = False
-    elif shared.cmd_opts.opt_unet_fp8_storage:
-        enable_fp8 = True
-    elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl:
-        enable_fp8 = True
-    else:
-        enable_fp8 = False
-
-    if enable_fp8:
+    if check_fp8(model):
         devices.fp8 = True
-        if model.is_sdxl:
-            cond_stage = model.conditioner
-        else:
-            cond_stage = model.cond_stage_model
-
-        for module in cond_stage.modules():
-            if isinstance(module, torch.nn.Linear):
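+        # set the VAE (first_stage_model) aside so its weights are not converted to fp8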
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, torch.nn.Conv2d):
                 module.to(torch.float8_e4m3fn)
-
-        if devices.device == devices.cpu:
-            for module in model.model.diffusion_model.modules():
-                if isinstance(module, torch.nn.Conv2d):
-                    module.to(torch.float8_e4m3fn)
-                elif isinstance(module, torch.nn.Linear):
-                    module.to(torch.float8_e4m3fn)
-        else:
-            model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
+            elif isinstance(module, torch.nn.Linear):
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
         devices.fp8 = False
@@ -769,7 +771,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
         return None


-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
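+    # forced_reload: reload the weights even when the checkpoint file is unchanged, e.g. after the fp8 setting changes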
     checkpoint_info = info or select_checkpoint()

     timer = Timer()
@@ -781,11 +784,14 @@ def reload_model_weights(sd_model=None, info=None):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load the state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename:
             return sd_model

     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model

     if sd_model is not None: