Преглед изворни кода

suggestions and fixes from the PR

AUTOMATIC пре 2 година
родитељ
комит
3ec7b705c7

+ 1 - 1
extensions-builtin/Lora/scripts/lora_script.py

@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
 
 
 
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(lora.available_loras)}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
 }))
 }))
 
 
 
 

+ 1 - 5
extensions-builtin/SwinIR/swinir_model_arch.py

@@ -644,17 +644,13 @@ class SwinIR(nn.Module):
     """
     """
 
 
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=None, num_heads=None,
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                  **kwargs):
                  **kwargs):
         super(SwinIR, self).__init__()
         super(SwinIR, self).__init__()
-
-        depths = depths or [6, 6, 6, 6]
-        num_heads = num_heads or [6, 6, 6, 6]
-
         num_in_ch = in_chans
         num_in_ch = in_chans
         num_out_ch = in_chans
         num_out_ch = in_chans
         num_feat = 64
         num_feat = 64

+ 2 - 9
extensions-builtin/SwinIR/swinir_model_arch_v2.py

@@ -74,12 +74,9 @@ class WindowAttention(nn.Module):
     """
     """
 
 
     def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
     def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=None):
+                 pretrained_window_size=(0, 0)):
 
 
         super().__init__()
         super().__init__()
-
-        pretrained_window_size = pretrained_window_size or [0, 0]
-
         self.dim = dim
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
         self.window_size = window_size  # Wh, Ww
         self.pretrained_window_size = pretrained_window_size
         self.pretrained_window_size = pretrained_window_size
@@ -701,17 +698,13 @@ class Swin2SR(nn.Module):
     """
     """
 
 
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=None, num_heads=None,
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True, 
                  window_size=7, mlp_ratio=4., qkv_bias=True, 
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                  **kwargs):
                  **kwargs):
         super(Swin2SR, self).__init__()
         super(Swin2SR, self).__init__()
-
-        depths = depths or [6, 6, 6, 6]
-        num_heads = num_heads or [6, 6, 6, 6]
-
         num_in_ch = in_chans
         num_in_ch = in_chans
         num_out_ch = in_chans
         num_out_ch = in_chans
         num_feat = 64
         num_feat = 64

+ 2 - 5
modules/codeformer/codeformer_arch.py

@@ -161,13 +161,10 @@ class Fuse_sft_block(nn.Module):
 class CodeFormer(VQAutoEncoder):
 class CodeFormer(VQAutoEncoder):
     def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
     def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
                 codebook_size=1024, latent_size=256,
                 codebook_size=1024, latent_size=256,
-                connect_list=None,
-                fix_modules=None):
+                connect_list=('32', '64', '128', '256'),
+                fix_modules=('quantize', 'generator')):
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
         super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
 
 
-        connect_list = connect_list or ['32', '64', '128', '256']
-        fix_modules = fix_modules or ['quantize', 'generator']
-
         if fix_modules is not None:
         if fix_modules is not None:
             for module in fix_modules:
             for module in fix_modules:
                 for param in getattr(self, module).parameters():
                 for param in getattr(self, module).parameters():

+ 2 - 2
modules/hypernetworks/ui.py

@@ -5,13 +5,13 @@ import modules.hypernetworks.hypernetwork
 from modules import devices, sd_hijack, shared
 from modules import devices, sd_hijack, shared
 
 
 not_available = ["hardswish", "multiheadattention"]
 not_available = ["hardswish", "multiheadattention"]
-keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
 
 
 
 
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
     filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
 
 
-    return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {filename}", ""
+    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
 
 
 
 
 def train_hypernetwork(*args):
 def train_hypernetwork(*args):

+ 2 - 2
modules/models/diffusion/uni_pc/uni_pc.py

@@ -275,8 +275,8 @@ def model_wrapper(
         A noise prediction model that accepts the noised data and the continuous time as the inputs.
         A noise prediction model that accepts the noised data and the continuous time as the inputs.
     """
     """
 
 
-    model_kwargs = model_kwargs or []
-    classifier_kwargs = classifier_kwargs or []
+    model_kwargs = model_kwargs or {}
+    classifier_kwargs = classifier_kwargs or {}
 
 
     def get_model_input_time(t_continuous):
     def get_model_input_time(t_continuous):
         """
         """

+ 1 - 1
modules/scripts_postprocessing.py

@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
             script_args = args[script.args_from:script.args_to]
             script_args = args[script.args_from:script.args_to]
 
 
             process_args = {}
             process_args = {}
-            for (name, component), value in zip(script.controls.items(), script_args):  # noqa B007
+            for (name, _component), value in zip(script.controls.items(), script_args):
                 process_args[name] = value
                 process_args[name] = value
 
 
             script.process(pp, **process_args)
             script.process(pp, **process_args)

+ 1 - 1
modules/sd_hijack_clip.py

@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             self.hijack.fixes = [x.fixes for x in batch_chunk]
             self.hijack.fixes = [x.fixes for x in batch_chunk]
 
 
             for fixes in self.hijack.fixes:
             for fixes in self.hijack.fixes:
-                for position, embedding in fixes:  # noqa: B007
+                for _position, embedding in fixes:
                     used_embeddings[embedding.name] = embedding
                     used_embeddings[embedding.name] = embedding
 
 
             z = self.process_tokens(tokens, multipliers)
             z = self.process_tokens(tokens, multipliers)

+ 1 - 1
modules/shared.py

@@ -381,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
     "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
     "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
     "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
     "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
     "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
     "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
-    "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(hypernetworks.keys())}, refresh=reload_hypernetworks),
+    "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", hypernetworks]}, refresh=reload_hypernetworks),
 }))
 }))
 
 
 options_templates.update(options_section(('ui', "User interface"), {
 options_templates.update(options_section(('ui', "User interface"), {

+ 1 - 2
modules/textual_inversion/textual_inversion.py

@@ -166,8 +166,7 @@ class EmbeddingDatabase:
         # textual inversion embeddings
         # textual inversion embeddings
         if 'string_to_param' in data:
         if 'string_to_param' in data:
             param_dict = data['string_to_param']
             param_dict = data['string_to_param']
-            if hasattr(param_dict, '_parameters'):
-                param_dict = param_dict._parameters  # fix for torch 1.12.1 loading saved file from torch 1.11
+            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
             emb = next(iter(param_dict.items()))[1]
             emb = next(iter(param_dict.items()))[1]
         # diffuser concepts
         # diffuser concepts

+ 2 - 2
modules/ui.py

@@ -1230,8 +1230,8 @@ def create_ui():
                         train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                         train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
                         create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
                         create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
 
 
-                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=list(shared.hypernetworks.keys()))
-                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks.keys())}, "refresh_train_hypernetwork_name")
+                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
+                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
 
 
                     with FormRow():
                     with FormRow():
                         embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                         embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")