Explorar o código

Merge pull request #13014 from AUTOMATIC1111/thread-safe-extranetworks-list_items

thread safe extra network list_items
AUTOMATIC1111 hai 1 ano
pai
achega
337bc4a2fb

+ 5 - 2
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
 
 
     def create_item(self, name, index=None, enable_filter=True):
     def create_item(self, name, index=None, enable_filter=True):
         lora_on_disk = networks.available_networks.get(name)
         lora_on_disk = networks.available_networks.get(name)
+        if lora_on_disk is None:
+            return
 
 
         path, ext = os.path.splitext(lora_on_disk.filename)
         path, ext = os.path.splitext(lora_on_disk.filename)
 
 
@@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         return item
         return item
 
 
     def list_items(self):
     def list_items(self):
-        for index, name in enumerate(networks.available_networks):
+        # instantiate a list to protect against concurrent modification
+        names = list(networks.available_networks)
+        for index, name in enumerate(names):
             item = self.create_item(name, index)
             item = self.create_item(name, index)
-
             if item is not None:
             if item is not None:
                 yield item
                 yield item
 
 

+ 7 - 1
modules/ui_extra_networks_checkpoints.py

@@ -17,6 +17,9 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
 
 
     def create_item(self, name, index=None, enable_filter=True):
     def create_item(self, name, index=None, enable_filter=True):
         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
+        if checkpoint is None:
+            return
+
         path, ext = os.path.splitext(checkpoint.filename)
         path, ext = os.path.splitext(checkpoint.filename)
         return {
         return {
             "name": checkpoint.name_for_extra,
             "name": checkpoint.name_for_extra,
@@ -32,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
         }
         }
 
 
     def list_items(self):
     def list_items(self):
+        # instantiate a list to protect against concurrent modification
         names = list(sd_models.checkpoints_list)
         names = list(sd_models.checkpoints_list)
         for index, name in enumerate(names):
         for index, name in enumerate(names):
-            yield self.create_item(name, index)
+            item = self.create_item(name, index)
+            if item is not None:
+                yield item
 
 
     def allowed_directories_for_previews(self):
     def allowed_directories_for_previews(self):
         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]

+ 10 - 3
modules/ui_extra_networks_hypernets.py

@@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
         shared.reload_hypernetworks()
         shared.reload_hypernetworks()
 
 
     def create_item(self, name, index=None, enable_filter=True):
     def create_item(self, name, index=None, enable_filter=True):
-        full_path = shared.hypernetworks[name]
+        full_path = shared.hypernetworks.get(name)
+        if full_path is None:
+            return
+
         path, ext = os.path.splitext(full_path)
         path, ext = os.path.splitext(full_path)
         sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
         sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
         shorthash = sha256[0:10] if sha256 else None
         shorthash = sha256[0:10] if sha256 else None
@@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
         }
         }
 
 
     def list_items(self):
     def list_items(self):
-        for index, name in enumerate(shared.hypernetworks):
-            yield self.create_item(name, index)
+        # instantiate a list to protect against concurrent modification
+        names = list(shared.hypernetworks)
+        for index, name in enumerate(names):
+            item = self.create_item(name, index)
+            if item is not None:
+                yield item
 
 
     def allowed_directories_for_previews(self):
     def allowed_directories_for_previews(self):
         return [shared.cmd_opts.hypernetwork_dir]
         return [shared.cmd_opts.hypernetwork_dir]

+ 8 - 2
modules/ui_extra_networks_textual_inversion.py

@@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
 
 
     def create_item(self, name, index=None, enable_filter=True):
     def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+        if embedding is None:
+            return
 
 
         path, ext = os.path.splitext(embedding.filename)
         path, ext = os.path.splitext(embedding.filename)
         return {
         return {
@@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
         }
         }
 
 
     def list_items(self):
     def list_items(self):
-        for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
-            yield self.create_item(name, index)
+        # instantiate a list to protect against concurrent modification
+        names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
+        for index, name in enumerate(names):
+            item = self.create_item(name, index)
+            if item is not None:
+                yield item
 
 
     def allowed_directories_for_previews(self):
     def allowed_directories_for_previews(self):
         return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
         return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)