sd_models_xl.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from __future__ import annotations
  2. import torch
  3. import sgm.models.diffusion
  4. import sgm.modules.diffusionmodules.denoiser_scaling
  5. import sgm.modules.diffusionmodules.discretizer
  6. from modules import devices, shared, prompt_parser
  7. from modules import torch_utils
  8. def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
  9. for embedder in self.conditioner.embedders:
  10. embedder.ucg_rate = 0.0
  11. width = getattr(batch, 'width', 1024) or 1024
  12. height = getattr(batch, 'height', 1024) or 1024
  13. is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
  14. aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
  15. devices_args = dict(device=devices.device, dtype=devices.dtype)
  16. sdxl_conds = {
  17. "txt": batch,
  18. "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
  19. "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
  20. "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
  21. "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
  22. }
  23. force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
  24. c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
  25. return c
  26. def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
  27. sd = self.model.state_dict()
  28. diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
  29. if diffusion_model_input is not None:
  30. if diffusion_model_input.shape[1] == 9:
  31. x = torch.cat([x] + cond['c_concat'], dim=1)
  32. return self.model(x, t, cond)
  33. def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
  34. return x
  35. sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
  36. sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
  37. sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
  38. def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
  39. res = []
  40. for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
  41. encoded = embedder.encode_embedding_init_text(init_text, nvpt)
  42. res.append(encoded)
  43. return torch.cat(res, dim=1)
  44. def tokenize(self: sgm.modules.GeneralConditioner, texts):
  45. for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
  46. return embedder.tokenize(texts)
  47. raise AssertionError('no tokenizer available')
  48. def process_texts(self, texts):
  49. for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
  50. return embedder.process_texts(texts)
  51. def get_target_prompt_token_count(self, token_count):
  52. for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
  53. return embedder.get_target_prompt_token_count(token_count)
  54. # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
  55. sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
  56. sgm.modules.GeneralConditioner.tokenize = tokenize
  57. sgm.modules.GeneralConditioner.process_texts = process_texts
  58. sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
  59. def extend_sdxl(model):
  60. """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
  61. dtype = torch_utils.get_param(model.model.diffusion_model).dtype
  62. model.model.diffusion_model.dtype = dtype
  63. model.model.conditioning_key = 'crossattn'
  64. model.cond_stage_key = 'txt'
  65. # model.cond_stage_model will be set in sd_hijack
  66. model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
  67. discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
  68. model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
  69. model.conditioner.wrapped = torch.nn.Module()
  70. sgm.modules.attention.print = shared.ldm_print
  71. sgm.modules.diffusionmodules.model.print = shared.ldm_print
  72. sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
  73. sgm.modules.encoders.modules.print = shared.ldm_print
  74. # this gets the code to load the vanilla attention that we override
  75. sgm.modules.attention.SDP_IS_AVAILABLE = True
  76. sgm.modules.attention.XFORMERS_IS_AVAILABLE = False