|
@@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False):
|
|
return getattr(importlib.import_module(module, package=None), cls)
|
|
return getattr(importlib.import_module(module, package=None), cls)
|
|
|
|
|
|
|
|
|
|
-def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|
|
|
|
|
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None):
|
|
from modules import sd_hijack
|
|
from modules import sd_hijack
|
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
|
|
|
|
|
@@ -801,7 +801,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|
else:
|
|
else:
|
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
|
|
|
|
|
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
|
|
|
|
|
+ if not checkpoint_config:
|
|
|
|
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
|
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
|
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
|
|
|
|
|
timer.record("find config")
|
|
timer.record("find config")
|
|
@@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
|
if sd_model is not None:
|
|
if sd_model is not None:
|
|
send_model_to_trash(sd_model)
|
|
send_model_to_trash(sd_model)
|
|
|
|
|
|
- load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
|
|
|
|
|
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config)
|
|
return model_data.sd_model
|
|
return model_data.sd_model
|
|
|
|
|
|
try:
|
|
try:
|