# ldsr_model_arch.py

import os
import gc
import time

import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack

cached_ldsr_model: torch.nn.Module = None
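# When the `ldsr_cached` option is enabled, the loaded network is kept in this
# module-level variable so subsequent upscales skip checkpoint loading entirely.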


# Create LDSR Class
class LDSR:
    def load_model_from_config(self, half_attention):
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}
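    # Loading notes: `strict=False` tolerates missing/unexpected keys in the
    # checkpoint, and the config target is overridden above so the checkpoint
    # instantiates the LatentDiffusionV1 class regardless of what the YAML says.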

    def __init__(self, model_path, yaml_path):
        self.modelPath = model_path
        self.yamlPath = yaml_path
    @staticmethod
    def run(model, selected_path, custom_steps, eta):
        example = get_cond(selected_path)

        n_runs = 1
        guider = None
        ckwargs = None
        ddim_use_x0_pred = False
        temperature = 1.
        custom_shape = None

        height, width = example["image"].shape[1:3]
        split_input = height >= 128 and width >= 128

        if split_input:
            ks = 128
            stride = 64
            vqf = 4
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,  # (sic) key name expected by ldm's split-input code
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        else:
            if hasattr(model, "split_input_params"):
                delattr(model, "split_input_params")

        x_t = None
        logs = None
        for _ in range(n_runs):
            if custom_shape is not None:
                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])

            logs = make_convolutional_sample(example, model,
                                             custom_steps=custom_steps,
                                             eta=eta, quantize_x0=False,
                                             custom_shape=custom_shape,
                                             temperature=temperature, noise_dropout=0.,
                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
                                             ddim_use_x0_pred=ddim_use_x0_pred
                                             )
        return logs
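
    # `super_resolution` below is the public entry point: it loads (or reuses the
    # cached) model, pre-downscales the input so the fixed 4x model produces the
    # requested target_scale, pads to a multiple of 64, samples, and converts the
    # [-1, 1] output tensor back to an 8-bit PIL image with the padding cropped off.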
    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        im_og = image
        width_og, height_og = im_og.size
        # If we can adjust the max upscale size, then the 4 below should be our variable
        down_sample_rate = target_scale / 4
        wd = width_og * down_sample_rate
        hd = height_og * down_sample_rate
        width_downsampled_pre = int(np.ceil(wd))
        height_downsampled_pre = int(np.ceil(hd))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        sample = (sample + 1.) / 2. * 255
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return a
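
# Usage sketch (illustrative only; the paths below are hypothetical, and the webui's
# `shared`/`sd_hijack` modules must already be initialized for loading to work):
#
#   ldsr = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")
#   result = ldsr.super_resolution(Image.open("input.png"), steps=100, target_scale=2)
#   result.save("output.png")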

def get_cond(selected_path):
    example = {}
    up_f = 4  # fixed 4x upscale factor of the LDSR model
    c = selected_path.convert('RGB')
    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
                                                    antialias=True)
    c_up = rearrange(c_up, '1 c h w -> 1 h w c')
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.  # rescale from [0, 1] to the model's expected [-1, 1] range

    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up
    return example
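
# `convsample_ddim` is a thin wrapper around ldm's DDIMSampler: it peels the batch
# size off `shape`, forwards the remaining arguments to `DDIMSampler.sample`, and
# returns both the final latents and the per-step intermediates.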
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_T=None
                    ):
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_T=x_T)
    return samples, intermediates
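
# `make_convolutional_sample` encodes the conditioning batch, draws a DDIM sample
# under the model's EMA weights, decodes it through the first-stage model, and
# returns a log dict whose "sample" entry holds the decoded image tensor.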
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None,
                              temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
    log = {}

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)
    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()

        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=z0,
                                                temperature=temperature, score_corrector=corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_T=x_T)
        t1 = time.time()

        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0
    return log