|
@@ -517,7 +517,7 @@ def get_empty_cond(sd_model):
|
|
|
|
|
|
|
|
|
def send_model_to_cpu(m):
|
|
|
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
|
+ if m.lowvram:
|
|
|
lowvram.send_everything_to_cpu()
|
|
|
else:
|
|
|
m.to(devices.cpu)
|
|
@@ -525,17 +525,17 @@ def send_model_to_cpu(m):
|
|
|
devices.torch_gc()
|
|
|
|
|
|
|
|
|
-def model_target_device():
|
|
|
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
|
+def model_target_device(m):
|
|
|
+ if lowvram.is_needed(m):
|
|
|
return devices.cpu
|
|
|
else:
|
|
|
return devices.device
|
|
|
|
|
|
|
|
|
def send_model_to_device(m):
|
|
|
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
|
- lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
|
|
|
- else:
|
|
|
+ lowvram.apply(m)
|
|
|
+
|
|
|
+ if not m.lowvram:
|
|
|
m.to(shared.device)
|
|
|
|
|
|
|
|
@@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|
|
'': torch.float16,
|
|
|
}
|
|
|
|
|
|
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
|
|
|
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
|
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
|
|
timer.record("load weights from state dict")
|
|
|
|
|
@@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|
|
script_callbacks.model_loaded_callback(sd_model)
|
|
|
timer.record("script callbacks")
|
|
|
|
|
|
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
|
|
+ if not sd_model.lowvram:
|
|
|
sd_model.to(devices.device)
|
|
|
timer.record("move model to device")
|
|
|
|