浏览代码

support m18

superhero-7 1 年之前
父节点
当前提交
702a1e1cc7
共有 4 个文件被更改,包括 244 次插入5 次删除
  1. 73 0
      configs/alt-diffusion-m18-inference.yaml
  2. 2 4
      modules/sd_hijack.py
  3. 5 1
      modules/sd_models_config.py
  4. 164 0
      modules/xlmr_m18.py

+ 73 - 0
configs/alt-diffusion-m18-inference.yaml

@@ -0,0 +1,73 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: modules.xlmr_m18.BertSeriesModelWithTransformation
+      params:
+        name: "XLMR-Large"

+ 2 - 4
modules/sd_hijack.py

@@ -5,7 +5,7 @@ from types import MethodType
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
 
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
@@ -208,11 +208,10 @@ class StableDiffusionModelHijack:
             else:
                 m.cond_stage_model = conditioner
 
-        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
             m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
-
         elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
@@ -258,7 +257,6 @@ class StableDiffusionModelHijack:
 
             if hasattr(m, 'cond_stage_model'):
                 delattr(m, 'cond_stage_model')
-
         elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
 

+ 5 - 1
modules/sd_models_config.py

@@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
 config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-
+config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
 
 def is_using_v_parameterization_for_sd2(state_dict):
     """
@@ -95,7 +95,11 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 8:
             return config_instruct_pix2pix
 
+    
+    # import pdb; pdb.set_trace()
     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
+        if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
+            return config_alt_diffusion_m18
         return config_alt_diffusion
 
     return config_default

+ 164 - 0
modules/xlmr_m18.py

@@ -0,0 +1,164 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    config_class = BertSeriesConfig
+
+    def __init__(self, config=None, **kargs):
+        # modify initialization for autoloading 
+        if config is None:
+            config = XLMRobertaConfig()
+            config.attention_probs_dropout_prob= 0.1
+            config.bos_token_id=0
+            config.eos_token_id=2
+            config.hidden_act='gelu'
+            config.hidden_dropout_prob=0.1
+            config.hidden_size=1024
+            config.initializer_range=0.02
+            config.intermediate_size=4096
+            config.layer_norm_eps=1e-05
+            config.max_position_embeddings=514
+
+            config.num_attention_heads=16
+            config.num_hidden_layers=24
+            config.output_past=True
+            config.pad_token_id=1
+            config.position_embedding_type= "absolute"
+
+            config.type_vocab_size= 1
+            config.use_cache=True
+            config.vocab_size= 250002
+            config.project_dim = 1024
+            config.learn_encoder = False
+        super().__init__(config)
+        self.roberta = XLMRobertaModel(config)
+        self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+        # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+        # self.pooler = lambda x: x[:,0]
+        # self.post_init()
+
+        self.has_pre_transformation = True
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_init()
+
+    def encode(self,c):
+        device = next(self.parameters()).device
+        text = self.tokenizer(c,
+                        truncation=True,
+                        max_length=77,
+                        return_length=False,
+                        return_overflowing_tokens=False,
+                        padding="max_length",
+                        return_tensors="pt")
+        text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+        text["attention_mask"] = torch.tensor(
+            text['attention_mask']).to(device)
+        features = self(**text)
+        return features['projection_state'] 
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) :
+        r"""
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+        outputs = self.roberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=return_dict,
+        )
+
+        # # last module outputs
+        # sequence_output = outputs[0]
+
+
+        # # project every module
+        # sequence_output_ln = self.pre_LN(sequence_output)
+
+        # # pooler
+        # pooler_output = self.pooler(sequence_output_ln)
+        # pooler_output = self.transformation(pooler_output)
+        # projection_state = self.transformation(outputs.last_hidden_state)
+
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return {
+                "projection_state": projection_state2,
+                "last_hidden_state": outputs.last_hidden_state,
+                "hidden_states": outputs.hidden_states,
+                "attentions": outputs.attentions,
+            }
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return {
+                "projection_state": projection_state,
+                "last_hidden_state": outputs.last_hidden_state,
+                "hidden_states": outputs.hidden_states,
+                "attentions": outputs.attentions,
+            }
+            
+        
+        # return {
+        #     'pooler_output':pooler_output,
+        #     'last_hidden_state':outputs.last_hidden_state,
+        #     'hidden_states':outputs.hidden_states,
+        #     'attentions':outputs.attentions,
+        #     'projection_state':projection_state,
+        #     'sequence_out': sequence_output
+        # }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+    base_model_prefix = 'roberta'
+    config_class= RobertaSeriesConfig