@@ -5,6 +5,8 @@ import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
 
+from modules import sd_hijack
+
 
 #################################################################################################
 ### Core/Utility
@@ -110,9 +112,9 @@ class CLIPEncoder(torch.nn.Module):
 
 
 class CLIPEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
         super().__init__()
-        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
         self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens):
@@ -127,7 +129,7 @@ class CLIPTextModel_(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
         self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
 
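
The three hunks together route this CLIP implementation's token lookup through webui's textual-inversion machinery: the plain embedding table is swapped for sd_hijack.TextualInversionEmbeddings at construction time, and each text encoder passes a textual_inversion_key ('clip_l' by default, read from its config dict) so that a multi-encoder embedding file can supply a different tensor per encoder. As a rough illustration of the idea, here is a minimal sketch of such a wrapper; apart from the textual_inversion_key parameter, everything here (the class name, loaded_embeddings, and its dict layout in particular) is a hypothetical stand-in, not webui's actual implementation:

import torch
from torch import nn


class TextualInversionEmbeddingsSketch(nn.Embedding):
    """Drop-in for nn.Embedding that splices textual-inversion vectors into its output."""

    def __init__(self, num_embeddings, embedding_dim, textual_inversion_key="clip_l", **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.textual_inversion_key = textual_inversion_key
        # Hypothetical registry: placeholder token id -> {encoder key: vector}.
        # Real embeddings may span several consecutive token slots; one vector
        # per token keeps the sketch short.
        self.loaded_embeddings: dict[int, dict[str, torch.Tensor]] = {}

    def forward(self, input_ids):
        out = super().forward(input_ids)  # ordinary vocabulary lookup first
        for token_id, per_encoder in self.loaded_embeddings.items():
            vec = per_encoder.get(self.textual_inversion_key)
            if vec is None:
                continue  # this embedding carries no tensor for this encoder
            mask = input_ids == token_id
            if mask.any():
                # overwrite the placeholder token's rows with the learned vector
                out[mask] = vec.to(device=out.device, dtype=out.dtype)
        return out

Swapping the layer at construction time leaves the rest of the CLIP stack untouched, and keying the lookup per encoder is what lets a single embedding file serve the model's multiple text encoders.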