Kaynağa Gözat

Variable dropout rate

Implements variable dropout rate from #4549

Fixes hypernetwork multiplier being able to modified during training, also fixes user-errors by setting multiplier value to lower values for training.

Changes function name to match torch.nn.module standard

Fixes RNG reset issue when generating previews by restoring RNG state
aria1th 2 yıl önce
ebeveyn
işleme
a4a5475cfa

+ 76 - 25
modules/hypernetworks/hypernetwork.py

@@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module):
     activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
 
     def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
-                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
+                 add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()
 
         assert layer_structure is not None, "layer_structure must not be None"
@@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module):
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
-            # Add dropout except last layer
-            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
-                linears.append(torch.nn.Dropout(p=0.3))
+            # Everything should be now parsed into dropout structure, and applied here.
+            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.
+            if dropout_structure is not None and dropout_structure[i+1] > 0:
+                assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
+                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
+            # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].
 
         self.linear = torch.nn.Sequential(*linears)
 
@@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
             state_dict[to] = x
 
     def forward(self, x):
-        return x + self.linear(x) * self.multiplier
+        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
 
     def trainables(self):
         layer_structure = []
@@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module):
 def apply_strength(value=None):
     HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
 
+#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
+def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
+    if layer_structure is None:
+        layer_structure = [1, 2, 1]
+    if not use_dropout:
+        return [0] * len(layer_structure)
+    dropout_values = [0]
+    dropout_values.extend([0.3] * (len(layer_structure) - 3))
+    if last_layer_dropout:
+        dropout_values.append(0.3)
+    else:
+        dropout_values.append(0)
+    dropout_values.append(0)
+    return dropout_values
+
 
 class Hypernetwork:
     filename = None
@@ -144,18 +162,22 @@ class Hypernetwork:
         self.add_layer_norm = add_layer_norm
         self.use_dropout = use_dropout
         self.activate_output = activate_output
-        self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True
+        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
+        self.dropout_structure = kwargs.get('dropout_structure', None)
+        if self.dropout_structure is None:
+            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
         self.optimizer_name = None
         self.optimizer_state_dict = None
+        self.optional_info = None
 
         for size in enable_sizes or []:
             self.layers[size] = (
                 HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                 HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
-                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
             )
-        self.eval_mode()
+        self.eval()
 
     def weights(self):
         res = []
@@ -164,14 +186,14 @@ class Hypernetwork:
                 res += layer.parameters()
         return res
 
-    def train_mode(self):
+    def train(self, mode=True):
         for k, layers in self.layers.items():
             for layer in layers:
-                layer.train()
+                layer.train(mode=mode)
                 for param in layer.parameters():
-                    param.requires_grad = True
+                    param.requires_grad = mode
 
-    def eval_mode(self):
+    def eval(self):
         for k, layers in self.layers.items():
             for layer in layers:
                 layer.eval()
@@ -191,11 +213,13 @@ class Hypernetwork:
         state_dict['activation_func'] = self.activation_func
         state_dict['is_layer_norm'] = self.add_layer_norm
         state_dict['weight_initialization'] = self.weight_init
-        state_dict['use_dropout'] = self.use_dropout
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
         state_dict['activate_output'] = self.activate_output
-        state_dict['last_layer_dropout'] = self.last_layer_dropout
+        state_dict['use_dropout'] = self.use_dropout
+        state_dict['dropout_structure'] = self.dropout_structure
+        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
+        state_dict['optional_info'] = self.optional_info if self.optional_info else None
 
         if self.optimizer_name is not None:
             optimizer_saved_dict['optimizer_name'] = self.optimizer_name
@@ -215,43 +239,56 @@ class Hypernetwork:
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
         print(self.layer_structure)
+        optional_info = state_dict.get('optional_info', None)
+        if optional_info is not None:
+            print(f"INFO:\n {optional_info}\n")
+            self.optional_info = optional_info
         self.activation_func = state_dict.get('activation_func', None)
         print(f"Activation function is {self.activation_func}")
         self.weight_init = state_dict.get('weight_initialization', 'Normal')
         print(f"Weight initialization is {self.weight_init}")
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
         print(f"Layer norm is set to {self.add_layer_norm}")
-        self.use_dropout = state_dict.get('use_dropout', False)
+        self.dropout_structure = state_dict.get('dropout_structure', None)
+        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
         print(f"Dropout usage is set to {self.use_dropout}" )
         self.activate_output = state_dict.get('activate_output', True)
         print(f"Activate last layer is set to {self.activate_output}")
         self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
+        # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
+        if self.dropout_structure is None:
+            print("Using previous dropout structure")
+            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
+        print(f"Dropout structure is set to {self.dropout_structure}")
 
         optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
-        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-        print(f"Optimizer name is {self.optimizer_name}")
+
         if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
             self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
         else:
             self.optimizer_state_dict = None
         if self.optimizer_state_dict:
+            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
             print("Loaded existing optimizer from checkpoint")
+            print(f"Optimizer name is {self.optimizer_name}")
         else:
+            self.optimizer_name = "AdamW"
             print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
                     HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                     HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
-                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
+                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                 )
 
         self.name = state_dict.get('name', self.name)
         self.step = state_dict.get('step', 0)
         self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
         self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+        self.eval()
 
 
 def list_hypernetworks(path):
@@ -379,9 +416,10 @@ def report_statistics(loss_info:dict):
             print(e)
 
 
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     # Remove illegal characters from name.
     name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+    assert name, "Name cannot be empty!"
 
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     if not overwrite_old:
@@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     if type(layer_structure) == str:
         layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
 
+    if use_dropout and dropout_structure and type(dropout_structure) == str:
+        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
+    else:
+        dropout_structure = [0] * len(layer_structure)
+
     hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
         name=name,
         enable_sizes=[int(x) for x in enable_sizes],
@@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
         weight_init=weight_init,
         add_layer_norm=add_layer_norm,
         use_dropout=use_dropout,
+        dropout_structure=dropout_structure
     )
     hypernet.save(fn)
 
@@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
         shared.sd_model.first_stage_model.to(devices.cpu)
 
     weights = hypernetwork.weights()
-    hypernetwork.train_mode()
+    hypernetwork.train()
 
     # Here we use optimizer from saved HN, or we can specify as UI option.
     if hypernetwork.optimizer_name in optimizer_dict:
@@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                 if images_dir is not None and steps_done % create_image_every == 0:
                     forced_filename = f'{hypernetwork_name}-{steps_done}'
                     last_saved_image = os.path.join(images_dir, forced_filename)
-                    hypernetwork.eval_mode()
+                    hypernetwork.eval()
+                    rng_state = torch.get_rng_state()
+                    cuda_rng_state = None
+                    if torch.cuda.is_available():
+                        cuda_rng_state = torch.cuda.get_rng_state_all()
                     shared.sd_model.cond_stage_model.to(devices.device)
                     shared.sd_model.first_stage_model.to(devices.device)
 
@@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                     if unload:
                         shared.sd_model.cond_stage_model.to(devices.cpu)
                         shared.sd_model.first_stage_model.to(devices.cpu)
-                    hypernetwork.train_mode()
+                    torch.set_rng_state(rng_state)
+                    if torch.cuda.is_available():
+                        torch.cuda.set_rng_state_all(cuda_rng_state)
+                    hypernetwork.train()
                     if image is not None:
                         shared.state.current_image = image
                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
@@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     finally:
         pbar.leave = False
         pbar.close()
-        hypernetwork.eval_mode()
+        hypernetwork.eval()
         #report_statistics(loss_dict)
 
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

+ 2 - 2
modules/hypernetworks/ui.py

@@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared
 not_available = ["hardswish", "multiheadattention"]
 keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
 
-def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
-    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
+    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
 
     return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
 

+ 3 - 1
modules/ui.py

@@ -1268,6 +1268,7 @@ def create_ui():
                     new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                     new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                     new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+                    new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
                     overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
 
                     with gr.Row():
@@ -1414,7 +1415,8 @@ def create_ui():
                 new_hypernetwork_activation_func,
                 new_hypernetwork_initialization_option,
                 new_hypernetwork_add_layer_norm,
-                new_hypernetwork_use_dropout
+                new_hypernetwork_use_dropout,
+                new_hypernetwork_dropout_structure
             ],
             outputs=[
                 train_hypernetwork_name,