|
@@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
|
|
shared.loaded_hypernetworks.append(hypernetwork)
|
|
shared.loaded_hypernetworks.append(hypernetwork)
|
|
|
|
|
|
|
|
|
|
-def find_closest_hypernetwork_name(search: str):
|
|
|
|
- if not search:
|
|
|
|
- return None
|
|
|
|
- search = search.lower()
|
|
|
|
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
|
|
|
|
- if not applicable:
|
|
|
|
- return None
|
|
|
|
- applicable = sorted(applicable, key=lambda name: len(name))
|
|
|
|
- return applicable[0]
|
|
|
|
-
|
|
|
|
-
|
|
|
|
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
|
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
|
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
|
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
|
|
|
|
|
|
@@ -446,18 +435,6 @@ def statistics(data):
|
|
return total_information, recent_information
|
|
return total_information, recent_information
|
|
|
|
|
|
|
|
|
|
-def report_statistics(loss_info:dict):
|
|
|
|
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
|
|
|
- for key in keys:
|
|
|
|
- try:
|
|
|
|
- print("Loss statistics for file " + key)
|
|
|
|
- info, recent = statistics(list(loss_info[key]))
|
|
|
|
- print(info)
|
|
|
|
- print(recent)
|
|
|
|
- except Exception as e:
|
|
|
|
- print(e)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
|
# Remove illegal characters from name.
|
|
# Remove illegal characters from name.
|
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
|
@@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|
pbar.leave = False
|
|
pbar.leave = False
|
|
pbar.close()
|
|
pbar.close()
|
|
hypernetwork.eval()
|
|
hypernetwork.eval()
|
|
- #report_statistics(loss_dict)
|
|
|
|
sd_hijack_checkpoint.remove()
|
|
sd_hijack_checkpoint.remove()
|
|
|
|
|
|
|
|
|