img2imgalt.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. from collections import namedtuple
  2. import numpy as np
  3. from tqdm import trange
  4. import modules.scripts as scripts
  5. import gradio as gr
  6. from modules import processing, shared, sd_samplers, prompt_parser
  7. from modules.processing import Processed
  8. from modules.shared import opts, cmd_opts, state
  9. import torch
  10. import k_diffusion as K
  11. from PIL import Image
  12. from torch import autocast
  13. from einops import rearrange, repeat
  14. import re
  15. def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
  16. x = p.init_latent
  17. s_in = x.new_ones([x.shape[0]])
  18. dnw = K.external.CompVisDenoiser(shared.sd_model)
  19. sigmas = dnw.get_sigmas(steps).flip(0)
  20. shared.state.sampling_steps = steps
  21. for i in trange(1, len(sigmas)):
  22. shared.state.sampling_step += 1
  23. x_in = torch.cat([x] * 2)
  24. sigma_in = torch.cat([sigmas[i] * s_in] * 2)
  25. cond_in = torch.cat([uncond, cond])
  26. image_conditioning = torch.cat([p.image_conditioning] * 2)
  27. cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
  28. c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
  29. t = dnw.sigma_to_t(sigma_in)
  30. eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
  31. denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
  32. denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
  33. d = (x - denoised) / sigmas[i]
  34. dt = sigmas[i] - sigmas[i - 1]
  35. x = x + d * dt
  36. sd_samplers.store_latent(x)
  37. # This shouldn't be necessary, but solved some VRAM issues
  38. del x_in, sigma_in, cond_in, c_out, c_in, t,
  39. del eps, denoised_uncond, denoised_cond, denoised, d, dt
  40. shared.state.nextjob()
  41. return x / x.std()
  42. Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
    """Sigma-adjusted variant of find_noise_for_image.

    Same reverse-Euler inversion of p.init_latent, but the model is
    evaluated at the *previous* step's sigma (and the first iteration is
    special-cased), which the linked issue suggests gives a more accurate
    reconstruction. Returns the recovered noise divided by the final
    (largest) sigma instead of being std-normalised.
    """
    x = p.init_latent

    s_in = x.new_ones([x.shape[0]])
    dnw = K.external.CompVisDenoiser(shared.sd_model)
    # Reversed schedule: iterate from low noise back up to high noise.
    sigmas = dnw.get_sigmas(steps).flip(0)

    shared.state.sampling_steps = steps

    for i in trange(1, len(sigmas)):
        shared.state.sampling_step += 1

        # Batch cond and uncond so a single model call serves both.
        x_in = torch.cat([x] * 2)
        # Note: sigma of the previous step, unlike the plain variant.
        sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
        cond_in = torch.cat([uncond, cond])

        image_conditioning = torch.cat([p.image_conditioning] * 2)
        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}

        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]

        # First step: sigmas[0] is 0 after the flip, so use sigmas[1] for
        # the timestep instead of the (degenerate) previous sigma.
        if i == 1:
            t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
        else:
            t = dnw.sigma_to_t(sigma_in)

        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)

        # Classifier-free guidance combination.
        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

        if i == 1:
            # Halved slope on the first step (see linked issue discussion).
            d = (x - denoised) / (2 * sigmas[i])
        else:
            d = (x - denoised) / sigmas[i - 1]

        dt = sigmas[i] - sigmas[i - 1]
        x = x + d * dt

        sd_samplers.store_latent(x)

        # This shouldn't be necessary, but solved some VRAM issues
        del x_in, sigma_in, cond_in, c_out, c_in, t,
        del eps, denoised_uncond, denoised_cond, denoised, d, dt

    shared.state.nextjob()

    return x / sigmas[-1]
class Script(scripts.Script):
    """img2img "alternative test" script: reconstructs the noise that would
    produce the input image under an original prompt, then resamples from a
    (optionally randomness-blended) version of that noise with the new prompt.
    """

    def __init__(self):
        # Cached(...) from the last decode, or None; lets repeated runs with
        # identical settings skip the expensive noise reconstruction.
        self.cache = None

    def title(self):
        return "img2img alternative test"

    def elem_id(self, item_id):
        """Build a sanitised HTML element id for a UI component.

        Prefixes with the tab name and the slugified script title, then
        strips everything but [a-z_0-9].
        """
        # NOTE(review): 'txt2txt' looks like a typo for 'txt2img', but since
        # show() restricts this script to img2img that branch never fires;
        # left unchanged because elem_ids may be referenced externally.
        gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id
        gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id)
        return gen_elem_id

    def show(self, is_img2img):
        # Only offered on the img2img tab.
        return is_img2img

    def ui(self, is_img2img):
        """Build the gradio controls; the returned list is passed positionally
        to run() after p."""
        info = gr.Markdown('''
        * `CFG Scale` should be 2 or lower.
        ''')

        override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler"))

        override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt"))
        original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt"))
        original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt"))

        override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps"))
        st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st"))

        override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength"))

        cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
        randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
        sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))

        return [
            info,
            override_sampler,
            override_prompt, original_prompt, original_negative_prompt,
            override_steps, st,
            override_strength,
            cfg, randomness, sigma_adjustment,
        ]

    def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
        """Apply the requested overrides, install a custom sampling function
        that decodes the image to noise and resamples, then process as usual.
        """
        # Override
        if override_sampler:
            p.sampler_name = "Euler"
        if override_prompt:
            p.prompt = original_prompt
            p.negative_prompt = original_negative_prompt
        if override_steps:
            p.steps = st
        if override_strength:
            p.denoising_strength = 1.0

        def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
            # Quantised copy of the latent used as a cache fingerprint.
            lat = (p.init_latent.cpu().numpy() * 10).astype(int)

            same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
                and self.cache.original_prompt == original_prompt \
                and self.cache.original_negative_prompt == original_negative_prompt \
                and self.cache.sigma_adjustment == sigma_adjustment
            # Latents are considered "the same" when they differ only slightly
            # after quantisation (sum of abs diffs under 100).
            same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100

            if same_everything:
                rec_noise = self.cache.noise
            else:
                # Reconstruction counts as an extra job for the progress UI.
                shared.state.job_count += 1
                cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
                uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
                if sigma_adjustment:
                    rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
                else:
                    rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
                self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)

            rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

            # Blend recovered and random noise; denominator keeps variance ~1.
            combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

            sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

            sigmas = sampler.model_wrap.get_sigmas(p.steps)

            noise_dt = combined_noise - (p.init_latent / sigmas[0])

            # presumably bumped so the forward pass does not reuse the exact
            # decode seed — TODO confirm against upstream intent
            p.seed = p.seed + 1

            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

        p.sample = sample_extra

        p.extra_generation_params["Decode prompt"] = original_prompt
        p.extra_generation_params["Decode negative prompt"] = original_negative_prompt
        p.extra_generation_params["Decode CFG scale"] = cfg
        p.extra_generation_params["Decode steps"] = st
        p.extra_generation_params["Randomness"] = randomness
        p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment

        processed = processing.process_images(p)

        return processed