@@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module):
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
     def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
+                 add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module):
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
-            # Add dropout except last layer
-            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
-                linears.append(torch.nn.Dropout(p=0.3))
+            # Everything should now be parsed into the dropout structure and applied here.
+            # Since dropout only ever follows a layer, the dropout structure should start and end with 0.
+            if dropout_structure is not None and dropout_structure[i+1] > 0:
+                assert 0 < dropout_structure[i+1] < 1, "Dropout probability must be 0 or a float between 0 and 1!"
+                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+            # Example: [1, 2, 1] has no dropout unless last_layer_dropout is True. For [1, 2, 2, 1], the structure is [0, 0.3, 0, 0] when last_layer_dropout is False and [0, 0.3, 0.3, 0] when it is True.
 
         self.linear = torch.nn.Sequential(*linears)
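For reference, a minimal standalone sketch of how a dropout structure shapes the resulting stack (dim and the structure values below are made up for illustration):

    import torch

    dim = 320
    layer_structure = [1, 2, 2, 1]
    dropout_structure = [0, 0.3, 0.3, 0]  # parse_dropout_structure([1, 2, 2, 1], True, True)

    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
        if dropout_structure[i+1] > 0:
            linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))

    # Linear(320, 640), Dropout(0.3), Linear(640, 640), Dropout(0.3), Linear(640, 320)
    print(torch.nn.Sequential(*linears))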
@@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
         state_dict[to] = x
 
     def forward(self, x):
-        return x + self.linear(x) * self.multiplier
+        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
 
     def trainables(self):
         layer_structure = []
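With this change the strength multiplier set from the UI only scales the residual at inference; during training the module always runs at full strength. A toy module showing the same pattern (names and sizes are illustrative, not from the codebase):

    import torch

    class Toy(torch.nn.Module):
        multiplier = 0.5  # stands in for HypernetworkModule.multiplier

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            # full strength while self.training is True, scaled otherwise
            return x + self.linear(x) * (Toy.multiplier if not self.training else 1)

    m = Toy()
    x = torch.randn(1, 4)
    m.train(); y_train = m(x)   # residual applied at strength 1
    m.eval();  y_eval = m(x)    # residual scaled by 0.5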
@@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module):
 def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
 
 
+# param layer_structure: sequence whose length determines the dropout structure; use_dropout: master toggle; last_layer_dropout: kept for compatibility with older checkpoints.
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+    if layer_structure is None:
+        layer_structure = [1, 2, 1]
+    if not use_dropout:
+        return [0] * len(layer_structure)
+    dropout_values = [0]
+    dropout_values.extend([0.3] * (len(layer_structure) - 3))
+    if last_layer_dropout:
+        dropout_values.append(0.3)
+    else:
+        dropout_values.append(0)
+    dropout_values.append(0)
+    return dropout_values
+
 class Hypernetwork:
     filename = None
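A quick sanity check of the mapping above (outputs follow directly from the function; 0.3 is its hard-coded default rate):

    print(parse_dropout_structure([1, 2, 1], use_dropout=True, last_layer_dropout=False))     # [0, 0, 0]
    print(parse_dropout_structure([1, 2, 1], use_dropout=True, last_layer_dropout=True))      # [0, 0.3, 0]
    print(parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=False))  # [0, 0.3, 0, 0]
    print(parse_dropout_structure([1, 2, 2, 1], use_dropout=True, last_layer_dropout=True))   # [0, 0.3, 0.3, 0]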
@@ -144,18 +162,22 @@ class Hypernetwork:
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
-        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
+        self.dropout_structure = kwargs.get('dropout_structure', None)
+        if self.dropout_structure is None:
+            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
         self.optimizer_name = None
         self.optimizer_state_dict = None
+        self.optional_info = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
                 HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                 HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
             )
-        self.eval_mode()
+        self.eval()
 
     def weights(self):
         res = []
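Callers can keep passing the legacy flags or hand over an explicit structure via kwargs; a rough sketch (argument values are illustrative and assume the signature above):

    hn = Hypernetwork(name="example", enable_sizes=[768],
                      layer_structure=[1, 2, 2, 1],
                      use_dropout=True, last_layer_dropout=False)   # -> dropout_structure [0, 0.3, 0, 0]
    hn2 = Hypernetwork(name="example2", enable_sizes=[768],
                       layer_structure=[1, 2, 2, 1],
                       dropout_structure=[0, 0.3, 0.15, 0])         # explicit per-layer rates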
@@ -164,14 +186,14 @@ class Hypernetwork:
                 res += layer.parameters()
         return res
 
-    def train_mode(self):
+    def train(self, mode=True):
         for k, layers in self.layers.items():
             for layer in layers:
-                layer.train()
+                layer.train(mode=mode)
                 for param in layer.parameters():
-                    param.requires_grad = True
+                    param.requires_grad = mode
 
-    def eval_mode(self):
+    def eval(self):
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.eval()
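Since Hypernetwork is a plain container rather than an nn.Module, these wrappers propagate the mode and requires_grad to the wrapped modules by hand; usage now mirrors torch conventions:

    hn.train()        # dropout active, parameters trainable
    hn.train(False)   # dropout off, parameters frozen
    hn.eval()         # inference mode for every HypernetworkModule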
@@ -191,11 +213,13 @@ class Hypernetwork:
         state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
         state_dict['weight_initialization'] = self.weight_init
-        state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
-        state_dict['last_layer_dropout'] = self.last_layer_dropout
+        state_dict['use_dropout'] = self.use_dropout
+        state_dict['dropout_structure'] = self.dropout_structure
+        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
+        state_dict['optional_info'] = self.optional_info if self.optional_info else None
 
         if self.optimizer_name is not None:
             optimizer_saved_dict['optimizer_name'] = self.optimizer_name
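last_layer_dropout is re-derived on save so that older builds reading the checkpoint still get a sensible flag; the rule is simply whether the second-to-last slot carries dropout:

    dropout_structure = [0, 0.3, 0.3, 0]
    last_layer_dropout = dropout_structure[-2] != 0  # True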
@@ -215,43 +239,56 @@ class Hypernetwork:
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
         print(self.layer_structure)
+        optional_info = state_dict.get('optional_info', None)
+        if optional_info is not None:
+            print(f"INFO:\n {optional_info}\n")
+            self.optional_info = optional_info
         self.activation_func = state_dict.get('activation_func', None)
         print(f"Activation function is {self.activation_func}")
         self.weight_init = state_dict.get('weight_initialization', 'Normal')
         print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
         print(f"Layer norm is set to {self.add_layer_norm}")
-        self.use_dropout = state_dict.get('use_dropout', False)
+        self.dropout_structure = state_dict.get('dropout_structure', None)
+        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}")
         self.activate_output = state_dict.get('activate_output', True)
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+        # The dropout structure must match the layer structure in length; every value must be in [0, 1), and the last value must be 0.
+        if self.dropout_structure is None:
+            print("Using previous dropout structure")
+            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+        print(f"Dropout structure is set to {self.dropout_structure}")
 
         optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
-        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-        print(f"Optimizer name is {self.optimizer_name}")
+
         if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
             self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
         else:
             self.optimizer_state_dict = None
         if self.optimizer_state_dict:
+            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
             print("Loaded existing optimizer from checkpoint")
+            print(f"Optimizer name is {self.optimizer_name}")
         else:
+            self.optimizer_name = "AdamW"
             print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
                     HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                     HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                 )
 
         self.name = state_dict.get('name', self.name)
         self.step = state_dict.get('step', 0)
         self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+        self.eval()
 
 
 def list_hypernetworks(path):
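On load, use_dropout is now inferred from the stored structure when one exists, falling back to the saved flag otherwise:

    dropout_structure = [0, 0.3, 0, 0]
    use_dropout = dropout_structure is not None and any(dropout_structure)  # True; [0, 0, 0] falls back to the stored flag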
@@ -379,9 +416,10 @@ def report_statistics(loss_info:dict):
         print(e)
 
 
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     # Remove illegal characters from name.
     name = "".join(x for x in name if (x.isalnum() or x in "._- "))
+    assert name, "Name cannot be empty!"
 
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     if not overwrite_old:
@@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     if type(layer_structure) == str:
         layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
 
+    if use_dropout and dropout_structure and type(dropout_structure) == str:
+        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
+    else:
+        dropout_structure = [0] * len(layer_structure)
+
     hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
         name=name,
         enable_sizes=[int(x) for x in enable_sizes],
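The UI hands the structure over as a comma-separated string, parsed the same way as layer_structure (input value illustrative):

    dropout_structure = "0, 0.3, 0.15, 0"
    parsed = [float(x.strip()) for x in dropout_structure.split(",")]  # [0.0, 0.3, 0.15, 0.0]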
@@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
         weight_init=weight_init,
         add_layer_norm=add_layer_norm,
         use_dropout=use_dropout,
+        dropout_structure=dropout_structure
     )
     hypernet.save(fn)
@@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
         shared.sd_model.first_stage_model.to(devices.cpu)
 
     weights = hypernetwork.weights()
-    hypernetwork.train_mode()
+    hypernetwork.train()
 
     # Use the optimizer saved with the hypernetwork, or the one specified as a UI option.
     if hypernetwork.optimizer_name in optimizer_dict:
@@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                 if images_dir is not None and steps_done % create_image_every == 0:
                     forced_filename = f'{hypernetwork_name}-{steps_done}'
                     last_saved_image = os.path.join(images_dir, forced_filename)
-                    hypernetwork.eval_mode()
+                    hypernetwork.eval()
+                    rng_state = torch.get_rng_state()
+                    cuda_rng_state = None
+                    if torch.cuda.is_available():
+                        cuda_rng_state = torch.cuda.get_rng_state_all()
                     shared.sd_model.cond_stage_model.to(devices.device)
                     shared.sd_model.first_stage_model.to(devices.device)
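The RNG snapshot keeps preview-image sampling from perturbing the training randomness (dropout masks, shuffling); the save/restore pair looks like this in isolation (a minimal sketch):

    import torch

    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

    _ = torch.randn(4)  # stand-in for preview sampling that consumes RNG

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state_all(cuda_rng_state)
    # training resumes with the exact pre-preview RNG stream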
@@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                     if unload:
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                         shared.sd_model.first_stage_model.to(devices.cpu)
-                    hypernetwork.train_mode()
+                    torch.set_rng_state(rng_state)
+                    if torch.cuda.is_available():
+                        torch.cuda.set_rng_state_all(cuda_rng_state)
+                    hypernetwork.train()
                     if image is not None:
                         shared.state.current_image = image
                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
@@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     finally:
         pbar.leave = False
         pbar.close()
-        hypernetwork.eval_mode()
+        hypernetwork.eval()
         #report_statistics(loss_dict)
 
         filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')