hypernetwork.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import datetime
  2. import glob
  3. import html
  4. import os
  5. import sys
  6. import traceback
  7. import tqdm
  8. import torch
  9. from ldm.util import default
  10. from modules import devices, shared, processing, sd_models
  11. import torch
  12. from torch import einsum
  13. from einops import rearrange, repeat
  14. import modules.textual_inversion.dataset
  15. class HypernetworkModule(torch.nn.Module):
  16. def __init__(self, dim, state_dict=None):
  17. super().__init__()
  18. self.linear1 = torch.nn.Linear(dim, dim * 2)
  19. self.linear2 = torch.nn.Linear(dim * 2, dim)
  20. if state_dict is not None:
  21. self.load_state_dict(state_dict, strict=True)
  22. else:
  23. self.linear1.weight.data.fill_(0.0001)
  24. self.linear1.bias.data.fill_(0.0001)
  25. self.linear2.weight.data.fill_(0.0001)
  26. self.linear2.bias.data.fill_(0.0001)
  27. self.to(devices.device)
  28. def forward(self, x):
  29. return x + (self.linear2(self.linear1(x)))
  30. class Hypernetwork:
  31. filename = None
  32. name = None
  33. def __init__(self, name=None):
  34. self.filename = None
  35. self.name = name
  36. self.layers = {}
  37. self.step = 0
  38. self.sd_checkpoint = None
  39. self.sd_checkpoint_name = None
  40. for size in [320, 640, 768, 1280]:
  41. self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
  42. def weights(self):
  43. res = []
  44. for k, layers in self.layers.items():
  45. for layer in layers:
  46. layer.train()
  47. res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
  48. return res
  49. def save(self, filename):
  50. state_dict = {}
  51. for k, v in self.layers.items():
  52. state_dict[k] = (v[0].state_dict(), v[1].state_dict())
  53. state_dict['step'] = self.step
  54. state_dict['name'] = self.name
  55. state_dict['sd_checkpoint'] = self.sd_checkpoint
  56. state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
  57. torch.save(state_dict, filename)
  58. def load(self, filename):
  59. self.filename = filename
  60. if self.name is None:
  61. self.name = os.path.splitext(os.path.basename(filename))[0]
  62. state_dict = torch.load(filename, map_location='cpu')
  63. for size, sd in state_dict.items():
  64. if type(size) == int:
  65. self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
  66. self.name = state_dict.get('name', self.name)
  67. self.step = state_dict.get('step', 0)
  68. self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
  69. self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
  70. def load_hypernetworks(path):
  71. res = {}
  72. for filename in glob.iglob(path + '**/*.pt', recursive=True):
  73. try:
  74. hn = Hypernetwork()
  75. hn.load(filename)
  76. res[hn.name] = hn
  77. except Exception:
  78. print(f"Error loading hypernetwork {filename}", file=sys.stderr)
  79. print(traceback.format_exc(), file=sys.stderr)
  80. return res
  81. def attention_CrossAttention_forward(self, x, context=None, mask=None):
  82. h = self.heads
  83. q = self.to_q(x)
  84. context = default(context, x)
  85. hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
  86. if hypernetwork_layers is not None:
  87. hypernetwork_k, hypernetwork_v = hypernetwork_layers
  88. self.hypernetwork_k = hypernetwork_k
  89. self.hypernetwork_v = hypernetwork_v
  90. context_k = hypernetwork_k(context)
  91. context_v = hypernetwork_v(context)
  92. else:
  93. context_k = context
  94. context_v = context
  95. k = self.to_k(context_k)
  96. v = self.to_v(context_v)
  97. q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
  98. sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  99. if mask is not None:
  100. mask = rearrange(mask, 'b ... -> b (...)')
  101. max_neg_value = -torch.finfo(sim.dtype).max
  102. mask = repeat(mask, 'b j -> (b h) () j', h=h)
  103. sim.masked_fill_(~mask, max_neg_value)
  104. # attention, what we cannot get enough of
  105. attn = sim.softmax(dim=-1)
  106. out = einsum('b i j, b j d -> b i d', attn, v)
  107. out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
  108. return self.to_out(out)
  109. def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
  110. assert hypernetwork_name, 'embedding not selected'
  111. shared.hypernetwork = shared.hypernetworks[hypernetwork_name]
  112. shared.state.textinfo = "Initializing hypernetwork training..."
  113. shared.state.job_count = steps
  114. filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
  115. log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
  116. if save_hypernetwork_every > 0:
  117. hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
  118. os.makedirs(hypernetwork_dir, exist_ok=True)
  119. else:
  120. hypernetwork_dir = None
  121. if create_image_every > 0:
  122. images_dir = os.path.join(log_directory, "images")
  123. os.makedirs(images_dir, exist_ok=True)
  124. else:
  125. images_dir = None
  126. cond_model = shared.sd_model.cond_stage_model
  127. shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
  128. with torch.autocast("cuda"):
  129. ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)
  130. hypernetwork = shared.hypernetworks[hypernetwork_name]
  131. weights = hypernetwork.weights()
  132. for weight in weights:
  133. weight.requires_grad = True
  134. optimizer = torch.optim.AdamW(weights, lr=learn_rate)
  135. losses = torch.zeros((32,))
  136. last_saved_file = "<none>"
  137. last_saved_image = "<none>"
  138. ititial_step = hypernetwork.step or 0
  139. if ititial_step > steps:
  140. return hypernetwork, filename
  141. pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
  142. for i, (x, text) in pbar:
  143. hypernetwork.step = i + ititial_step
  144. if hypernetwork.step > steps:
  145. break
  146. if shared.state.interrupted:
  147. break
  148. with torch.autocast("cuda"):
  149. c = cond_model([text])
  150. x = x.to(devices.device)
  151. loss = shared.sd_model(x.unsqueeze(0), c)[0]
  152. del x
  153. losses[hypernetwork.step % losses.shape[0]] = loss.item()
  154. optimizer.zero_grad()
  155. loss.backward()
  156. optimizer.step()
  157. pbar.set_description(f"loss: {losses.mean():.7f}")
  158. if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
  159. last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
  160. hypernetwork.save(last_saved_file)
  161. if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
  162. last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
  163. preview_text = text if preview_image_prompt == "" else preview_image_prompt
  164. p = processing.StableDiffusionProcessingTxt2Img(
  165. sd_model=shared.sd_model,
  166. prompt=preview_text,
  167. steps=20,
  168. do_not_save_grid=True,
  169. do_not_save_samples=True,
  170. )
  171. processed = processing.process_images(p)
  172. image = processed.images[0]
  173. shared.state.current_image = image
  174. image.save(last_saved_image)
  175. last_saved_image += f", prompt: {preview_text}"
  176. shared.state.job_no = hypernetwork.step
  177. shared.state.textinfo = f"""
  178. <p>
  179. Loss: {losses.mean():.7f}<br/>
  180. Step: {hypernetwork.step}<br/>
  181. Last prompt: {html.escape(text)}<br/>
  182. Last saved embedding: {html.escape(last_saved_file)}<br/>
  183. Last saved image: {html.escape(last_saved_image)}<br/>
  184. </p>
  185. """
  186. checkpoint = sd_models.select_checkpoint()
  187. hypernetwork.sd_checkpoint = checkpoint.hash
  188. hypernetwork.sd_checkpoint_name = checkpoint.model_name
  189. hypernetwork.save(filename)
  190. return hypernetwork, filename