Преглед изворни кода

Add warning when meet emb name conflicting

Choose standalone embedding (in /embeddings folder) first
Kohaku-Blueleaf пре 1 година
родитељ
комит
81e94de318
2 измењених фајлова са 81 додато и 32 уклоњено
  1. 33 0
      extensions-builtin/Lora/lora_logger.py
  2. 48 32
      extensions-builtin/Lora/networks.py

+ 33 - 0
extensions-builtin/Lora/lora_logger.py

@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)

+ 48 - 32
extensions-builtin/Lora/networks.py

@@ -17,6 +17,8 @@ from typing import Union
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 from modules.textual_inversion.textual_inversion import Embedding
 
+from lora_logger import logger
+
 module_types = [
     network_lora.ModuleTypeLora(),
     network_hada.ModuleTypeHada(),
@@ -206,7 +208,40 @@ def load_network(name, network_on_disk):
 
         net.modules[key] = net_module
 
-    net.bundle_embeddings = bundle_embeddings
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        # textual inversion embeddings
+        if 'string_to_param' in data:
+            param_dict = data['string_to_param']
+            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+            emb = next(iter(param_dict.items()))[1]
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        else:
+            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
+
+        embedding = Embedding(vec, emb_name)
+        embedding.vectors = vectors
+        embedding.shape = shape
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
 
     if keys_failed_to_match:
         logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     for net in loaded_networks:
         if net.name in names:
             already_loaded[net.name] = net
-        for emb_name in net.bundle_embeddings:
-            emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded:
+                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
 
     loaded_networks.clear()
 
@@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
         net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
         loaded_networks.append(net)
 
-        for emb_name, data in net.bundle_embeddings.items():
-            # textual inversion embeddings
-            if 'string_to_param' in data:
-                param_dict = data['string_to_param']
-                param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-                emb = next(iter(param_dict.items()))[1]
-                vec = emb.detach().to(devices.device, dtype=torch.float32)
-                shape = vec.shape[-1]
-                vectors = vec.shape[0]
-            elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-                vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-                shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-                vectors = data['clip_g'].shape[0]
-            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
-                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-                emb = next(iter(data.values()))
-                if len(emb.shape) == 1:
-                    emb = emb.unsqueeze(0)
-                vec = emb.detach().to(devices.device, dtype=torch.float32)
-                shape = vec.shape[-1]
-                vectors = vec.shape[0]
-            else:
-                raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-            embedding = Embedding(vec, emb_name)
-            embedding.vectors = vectors
-            embedding.shape = shape
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
 
+            embedding.loaded = False
             if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
                 emb_db.register_embedding(embedding, shared.sd_model)
             else:
                 emb_db.skipped_embeddings[name] = embedding