@@ -159,48 +159,61 @@ def run_pnginfo(image):
    return '', geninfo, info


-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name):
    # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
-    def weighted_sum(theta0, theta1, alpha):
+    def weighted_sum(theta0, theta1, theta2, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
-    def sigmoid(theta0, theta1, alpha):
+    def sigmoid(theta0, theta1, theta2, alpha):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
-    def inv_sigmoid(theta0, theta1, alpha):
+    def inv_sigmoid(theta0, theta1, theta2, alpha):
        import math
        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
        return theta0 + ((theta1 - theta0) * alpha)

+    def add_difference(theta0, theta1, theta2, alpha):
+        return theta0 + (theta1 - theta2) * (1.0 - alpha)
+
    primary_model_info = sd_models.checkpoints_list[primary_model_name]
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+    teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)

    print(f"Loading {primary_model_info.filename}...")
    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)

    print(f"Loading {secondary_model_info.filename}...")
    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-
-    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)

+    if teritary_model_info is not None:
+        print(f"Loading {teritary_model_info.filename}...")
+        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+        theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+    else:
+        theta_2 = None
+
    theta_funcs = {
"Weighted Sum": weighted_sum,
|
|
|
"Sigmoid": sigmoid,
|
|
|
"Inverse Sigmoid": inv_sigmoid,
|
|
|
+ "Add difference": add_difference,
|
|
|
}
|
|
|
theta_func = theta_funcs[interp_method]
|
|
|
|
|
|
print(f"Merging...")
|
|
|
+
|
|
|
for key in tqdm.tqdm(theta_0.keys()):
|
|
|
if 'model' in key and key in theta_1:
|
|
|
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
|
|
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
|
|
            if save_as_half:
                theta_0[key] = theta_0[key].half()

+    # I believe this part should be discarded, but I'll leave it for now until I am sure
    for key in theta_1.keys():
        if 'model' in key and key not in theta_0:
            theta_0[key] = theta_1[key]
@@ -219,4 +232,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
    sd_models.list_models()

    print(f"Checkpoint saved.")
-    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
+    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
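As a reading aid for the new merge mode, below is a minimal sketch of the arithmetic the theta functions in the diff perform, with toy torch tensors standing in for checkpoint weights; the tensor values and the 0.3 amount are made up for illustration, and alpha is reversed exactly as in the merge loop (alpha = 1.0 - interp_amount).

import torch

theta0 = torch.tensor([1.0, 2.0, 3.0])  # primary model weights
theta1 = torch.tensor([3.0, 2.0, 1.0])  # secondary model weights
theta2 = torch.tensor([0.5, 0.5, 0.5])  # teritary model weights (base for "Add difference")

interp_amount = 0.3            # illustrative UI value
alpha = 1.0 - interp_amount    # reversed before the theta function is called

# "Weighted Sum": plain linear interpolation; theta2 is ignored
weighted = ((1 - alpha) * theta0) + (alpha * theta1)  # -> [2.4, 2.0, 1.6]

# "Add difference": add the (secondary - teritary) delta onto the primary weights
added = theta0 + (theta1 - theta2) * (1.0 - alpha)    # -> [1.75, 2.45, 3.15]

Because the caller already reverses the amount, the factor (1.0 - alpha) inside add_difference works out to the original interp_amount.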