textual_inversion.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. import os
  2. from collections import namedtuple
  3. import torch
  4. import tqdm
  5. import html
  6. import datetime
  7. import csv
  8. import safetensors.torch
  9. import numpy as np
  10. from PIL import Image, PngImagePlugin
  11. from torch.utils.tensorboard import SummaryWriter
  12. from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
  13. import modules.textual_inversion.dataset
  14. from modules.textual_inversion.learn_schedule import LearnRateScheduler
  15. from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
  16. from modules.textual_inversion.logging import save_settings_to_file
  17. TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
  18. textual_inversion_templates = {}
  19. def list_textual_inversion_templates():
  20. textual_inversion_templates.clear()
  21. for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
  22. for fn in fns:
  23. path = os.path.join(root, fn)
  24. textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
  25. return textual_inversion_templates
  26. class Embedding:
  27. def __init__(self, vec, name, step=None):
  28. self.vec = vec
  29. self.name = name
  30. self.step = step
  31. self.shape = None
  32. self.vectors = 0
  33. self.cached_checksum = None
  34. self.sd_checkpoint = None
  35. self.sd_checkpoint_name = None
  36. self.optimizer_state_dict = None
  37. self.filename = None
  38. def save(self, filename):
  39. embedding_data = {
  40. "string_to_token": {"*": 265},
  41. "string_to_param": {"*": self.vec},
  42. "name": self.name,
  43. "step": self.step,
  44. "sd_checkpoint": self.sd_checkpoint,
  45. "sd_checkpoint_name": self.sd_checkpoint_name,
  46. }
  47. torch.save(embedding_data, filename)
  48. if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
  49. optimizer_saved_dict = {
  50. 'hash': self.checksum(),
  51. 'optimizer_state_dict': self.optimizer_state_dict,
  52. }
  53. torch.save(optimizer_saved_dict, f"{filename}.optim")
  54. def checksum(self):
  55. if self.cached_checksum is not None:
  56. return self.cached_checksum
  57. def const_hash(a):
  58. r = 0
  59. for v in a:
  60. r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
  61. return r
  62. self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
  63. return self.cached_checksum
  64. class DirWithTextualInversionEmbeddings:
  65. def __init__(self, path):
  66. self.path = path
  67. self.mtime = None
  68. def has_changed(self):
  69. if not os.path.isdir(self.path):
  70. return False
  71. mt = os.path.getmtime(self.path)
  72. if self.mtime is None or mt > self.mtime:
  73. return True
  74. def update(self):
  75. if not os.path.isdir(self.path):
  76. return
  77. self.mtime = os.path.getmtime(self.path)
  78. class EmbeddingDatabase:
  79. def __init__(self):
  80. self.ids_lookup = {}
  81. self.word_embeddings = {}
  82. self.skipped_embeddings = {}
  83. self.expected_shape = -1
  84. self.embedding_dirs = {}
  85. self.previously_displayed_embeddings = ()
  86. def add_embedding_dir(self, path):
  87. self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
  88. def clear_embedding_dirs(self):
  89. self.embedding_dirs.clear()
  90. def register_embedding(self, embedding, model):
  91. return self.register_embedding_by_name(embedding, model, embedding.name)
  92. def register_embedding_by_name(self, embedding, model, name):
  93. ids = model.cond_stage_model.tokenize([name])[0]
  94. first_id = ids[0]
  95. if first_id not in self.ids_lookup:
  96. self.ids_lookup[first_id] = []
  97. if name in self.word_embeddings:
  98. # remove old one from the lookup list
  99. lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
  100. else:
  101. lookup = self.ids_lookup[first_id]
  102. if embedding is not None:
  103. lookup += [(ids, embedding)]
  104. self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
  105. if embedding is None:
  106. # unregister embedding with specified name
  107. if name in self.word_embeddings:
  108. del self.word_embeddings[name]
  109. if len(self.ids_lookup[first_id])==0:
  110. del self.ids_lookup[first_id]
  111. return None
  112. self.word_embeddings[name] = embedding
  113. return embedding
  114. def get_expected_shape(self):
  115. vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
  116. return vec.shape[1]
  117. def load_from_file(self, path, filename):
  118. name, ext = os.path.splitext(filename)
  119. ext = ext.upper()
  120. if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
  121. _, second_ext = os.path.splitext(name)
  122. if second_ext.upper() == '.PREVIEW':
  123. return
  124. embed_image = Image.open(path)
  125. if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
  126. data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
  127. name = data.get('name', name)
  128. else:
  129. data = extract_image_data_embed(embed_image)
  130. if data:
  131. name = data.get('name', name)
  132. else:
  133. # if data is None, means this is not an embeding, just a preview image
  134. return
  135. elif ext in ['.BIN', '.PT']:
  136. data = torch.load(path, map_location="cpu")
  137. elif ext in ['.SAFETENSORS']:
  138. data = safetensors.torch.load_file(path, device="cpu")
  139. else:
  140. return
  141. # textual inversion embeddings
  142. if 'string_to_param' in data:
  143. param_dict = data['string_to_param']
  144. param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
  145. assert len(param_dict) == 1, 'embedding file has multiple terms in it'
  146. emb = next(iter(param_dict.items()))[1]
  147. # diffuser concepts
  148. elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
  149. assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
  150. emb = next(iter(data.values()))
  151. if len(emb.shape) == 1:
  152. emb = emb.unsqueeze(0)
  153. else:
  154. raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
  155. vec = emb.detach().to(devices.device, dtype=torch.float32)
  156. embedding = Embedding(vec, name)
  157. embedding.step = data.get('step', None)
  158. embedding.sd_checkpoint = data.get('sd_checkpoint', None)
  159. embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
  160. embedding.vectors = vec.shape[0]
  161. embedding.shape = vec.shape[-1]
  162. embedding.filename = path
  163. if self.expected_shape == -1 or self.expected_shape == embedding.shape:
  164. self.register_embedding(embedding, shared.sd_model)
  165. else:
  166. self.skipped_embeddings[name] = embedding
  167. def load_from_dir(self, embdir):
  168. if not os.path.isdir(embdir.path):
  169. return
  170. for root, _, fns in os.walk(embdir.path, followlinks=True):
  171. for fn in fns:
  172. try:
  173. fullfn = os.path.join(root, fn)
  174. if os.stat(fullfn).st_size == 0:
  175. continue
  176. self.load_from_file(fullfn, fn)
  177. except Exception:
  178. errors.report(f"Error loading embedding {fn}", exc_info=True)
  179. continue
  180. def load_textual_inversion_embeddings(self, force_reload=False):
  181. if not force_reload:
  182. need_reload = False
  183. for embdir in self.embedding_dirs.values():
  184. if embdir.has_changed():
  185. need_reload = True
  186. break
  187. if not need_reload:
  188. return
  189. self.ids_lookup.clear()
  190. self.word_embeddings.clear()
  191. self.skipped_embeddings.clear()
  192. self.expected_shape = self.get_expected_shape()
  193. for embdir in self.embedding_dirs.values():
  194. self.load_from_dir(embdir)
  195. embdir.update()
  196. # re-sort word_embeddings because load_from_dir may not load in alphabetic order.
  197. # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
  198. sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
  199. self.word_embeddings.clear()
  200. self.word_embeddings.update(sorted_word_embeddings)
  201. displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
  202. if self.previously_displayed_embeddings != displayed_embeddings:
  203. self.previously_displayed_embeddings = displayed_embeddings
  204. print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
  205. if self.skipped_embeddings:
  206. print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
  207. def find_embedding_at_position(self, tokens, offset):
  208. token = tokens[offset]
  209. possible_matches = self.ids_lookup.get(token, None)
  210. if possible_matches is None:
  211. return None, None
  212. for ids, embedding in possible_matches:
  213. if tokens[offset:offset + len(ids)] == ids:
  214. return embedding, len(ids)
  215. return None, None
  216. def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
  217. cond_model = shared.sd_model.cond_stage_model
  218. with devices.autocast():
  219. cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
  220. #cond_model expects at least some text, so we provide '*' as backup.
  221. embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
  222. vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
  223. #Only copy if we provided an init_text, otherwise keep vectors as zeros
  224. if init_text:
  225. for i in range(num_vectors_per_token):
  226. vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
  227. # Remove illegal characters from name.
  228. name = "".join( x for x in name if (x.isalnum() or x in "._- "))
  229. fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
  230. if not overwrite_old:
  231. assert not os.path.exists(fn), f"file {fn} already exists"
  232. embedding = Embedding(vec, name)
  233. embedding.step = 0
  234. embedding.save(fn)
  235. return fn
  236. def write_loss(log_directory, filename, step, epoch_len, values):
  237. if shared.opts.training_write_csv_every == 0:
  238. return
  239. if step % shared.opts.training_write_csv_every != 0:
  240. return
  241. write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
  242. with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
  243. csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
  244. if write_csv_header:
  245. csv_writer.writeheader()
  246. epoch = (step - 1) // epoch_len
  247. epoch_step = (step - 1) % epoch_len
  248. csv_writer.writerow({
  249. "step": step,
  250. "epoch": epoch,
  251. "epoch_step": epoch_step,
  252. **values,
  253. })
  254. def tensorboard_setup(log_directory):
  255. os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
  256. return SummaryWriter(
  257. log_dir=os.path.join(log_directory, "tensorboard"),
  258. flush_secs=shared.opts.training_tensorboard_flush_every)
  259. def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
  260. tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
  261. tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
  262. tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
  263. tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
  264. def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
  265. tensorboard_writer.add_scalar(tag=tag,
  266. scalar_value=value, global_step=step)
  267. def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
  268. # Convert a pil image to a torch tensor
  269. img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
  270. img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
  271. len(pil_image.getbands()))
  272. img_tensor = img_tensor.permute((2, 0, 1))
  273. tensorboard_writer.add_image(tag, img_tensor, global_step=step)
  274. def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
  275. assert model_name, f"{name} not selected"
  276. assert learn_rate, "Learning rate is empty or 0"
  277. assert isinstance(batch_size, int), "Batch size must be integer"
  278. assert batch_size > 0, "Batch size must be positive"
  279. assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
  280. assert gradient_step > 0, "Gradient accumulation step must be positive"
  281. assert data_root, "Dataset directory is empty"
  282. assert os.path.isdir(data_root), "Dataset directory doesn't exist"
  283. assert os.listdir(data_root), "Dataset directory is empty"
  284. assert template_filename, "Prompt template file not selected"
  285. assert template_file, f"Prompt template file {template_filename} not found"
  286. assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
  287. assert steps, "Max steps is empty or 0"
  288. assert isinstance(steps, int), "Max steps must be integer"
  289. assert steps > 0, "Max steps must be positive"
  290. assert isinstance(save_model_every, int), "Save {name} must be integer"
  291. assert save_model_every >= 0, "Save {name} must be positive or 0"
  292. assert isinstance(create_image_every, int), "Create image must be integer"
  293. assert create_image_every >= 0, "Create image must be positive or 0"
  294. if save_model_every or create_image_every:
  295. assert log_directory, "Log directory is empty"
  296. def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
  297. save_embedding_every = save_embedding_every or 0
  298. create_image_every = create_image_every or 0
  299. template_file = textual_inversion_templates.get(template_filename, None)
  300. validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
  301. template_file = template_file.path
  302. shared.state.job = "train-embedding"
  303. shared.state.textinfo = "Initializing textual inversion training..."
  304. shared.state.job_count = steps
  305. filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
  306. log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
  307. unload = shared.opts.unload_models_when_training
  308. if save_embedding_every > 0:
  309. embedding_dir = os.path.join(log_directory, "embeddings")
  310. os.makedirs(embedding_dir, exist_ok=True)
  311. else:
  312. embedding_dir = None
  313. if create_image_every > 0:
  314. images_dir = os.path.join(log_directory, "images")
  315. os.makedirs(images_dir, exist_ok=True)
  316. else:
  317. images_dir = None
  318. if create_image_every > 0 and save_image_with_stored_embedding:
  319. images_embeds_dir = os.path.join(log_directory, "image_embeddings")
  320. os.makedirs(images_embeds_dir, exist_ok=True)
  321. else:
  322. images_embeds_dir = None
  323. hijack = sd_hijack.model_hijack
  324. embedding = hijack.embedding_db.word_embeddings[embedding_name]
  325. checkpoint = sd_models.select_checkpoint()
  326. initial_step = embedding.step or 0
  327. if initial_step >= steps:
  328. shared.state.textinfo = "Model has already been trained beyond specified max steps"
  329. return embedding, filename
  330. scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
  331. clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
  332. torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
  333. None
  334. if clip_grad:
  335. clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
  336. # dataset loading may take a while, so input validations and early returns should be done before this
  337. shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
  338. old_parallel_processing_allowed = shared.parallel_processing_allowed
  339. if shared.opts.training_enable_tensorboard:
  340. tensorboard_writer = tensorboard_setup(log_directory)
  341. pin_memory = shared.opts.pin_memory
  342. ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
  343. if shared.opts.save_training_settings_to_txt:
  344. save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
  345. latent_sampling_method = ds.latent_sampling_method
  346. dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
  347. if unload:
  348. shared.parallel_processing_allowed = False
  349. shared.sd_model.first_stage_model.to(devices.cpu)
  350. embedding.vec.requires_grad = True
  351. optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
  352. if shared.opts.save_optimizer_state:
  353. optimizer_state_dict = None
  354. if os.path.exists(f"{filename}.optim"):
  355. optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
  356. if embedding.checksum() == optimizer_saved_dict.get('hash', None):
  357. optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
  358. if optimizer_state_dict is not None:
  359. optimizer.load_state_dict(optimizer_state_dict)
  360. print("Loaded existing optimizer from checkpoint")
  361. else:
  362. print("No saved optimizer exists in checkpoint")
  363. scaler = torch.cuda.amp.GradScaler()
  364. batch_size = ds.batch_size
  365. gradient_step = ds.gradient_step
  366. # n steps = batch_size * gradient_step * n image processed
  367. steps_per_epoch = len(ds) // batch_size // gradient_step
  368. max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
  369. loss_step = 0
  370. _loss_step = 0 #internal
  371. last_saved_file = "<none>"
  372. last_saved_image = "<none>"
  373. forced_filename = "<none>"
  374. embedding_yet_to_be_embedded = False
  375. is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
  376. img_c = None
  377. pbar = tqdm.tqdm(total=steps - initial_step)
  378. try:
  379. sd_hijack_checkpoint.add()
  380. for _ in range((steps-initial_step) * gradient_step):
  381. if scheduler.finished:
  382. break
  383. if shared.state.interrupted:
  384. break
  385. for j, batch in enumerate(dl):
  386. # works as a drop_last=True for gradient accumulation
  387. if j == max_steps_per_epoch:
  388. break
  389. scheduler.apply(optimizer, embedding.step)
  390. if scheduler.finished:
  391. break
  392. if shared.state.interrupted:
  393. break
  394. if clip_grad:
  395. clip_grad_sched.step(embedding.step)
  396. with devices.autocast():
  397. x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
  398. if use_weight:
  399. w = batch.weight.to(devices.device, non_blocking=pin_memory)
  400. c = shared.sd_model.cond_stage_model(batch.cond_text)
  401. if is_training_inpainting_model:
  402. if img_c is None:
  403. img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
  404. cond = {"c_concat": [img_c], "c_crossattn": [c]}
  405. else:
  406. cond = c
  407. if use_weight:
  408. loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
  409. del w
  410. else:
  411. loss = shared.sd_model.forward(x, cond)[0] / gradient_step
  412. del x
  413. _loss_step += loss.item()
  414. scaler.scale(loss).backward()
  415. # go back until we reach gradient accumulation steps
  416. if (j + 1) % gradient_step != 0:
  417. continue
  418. if clip_grad:
  419. clip_grad(embedding.vec, clip_grad_sched.learn_rate)
  420. scaler.step(optimizer)
  421. scaler.update()
  422. embedding.step += 1
  423. pbar.update()
  424. optimizer.zero_grad(set_to_none=True)
  425. loss_step = _loss_step
  426. _loss_step = 0
  427. steps_done = embedding.step + 1
  428. epoch_num = embedding.step // steps_per_epoch
  429. epoch_step = embedding.step % steps_per_epoch
  430. description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
  431. pbar.set_description(description)
  432. if embedding_dir is not None and steps_done % save_embedding_every == 0:
  433. # Before saving, change name to match current checkpoint.
  434. embedding_name_every = f'{embedding_name}-{steps_done}'
  435. last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
  436. save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
  437. embedding_yet_to_be_embedded = True
  438. write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
  439. "loss": f"{loss_step:.7f}",
  440. "learn_rate": scheduler.learn_rate
  441. })
  442. if images_dir is not None and steps_done % create_image_every == 0:
  443. forced_filename = f'{embedding_name}-{steps_done}'
  444. last_saved_image = os.path.join(images_dir, forced_filename)
  445. shared.sd_model.first_stage_model.to(devices.device)
  446. p = processing.StableDiffusionProcessingTxt2Img(
  447. sd_model=shared.sd_model,
  448. do_not_save_grid=True,
  449. do_not_save_samples=True,
  450. do_not_reload_embeddings=True,
  451. )
  452. if preview_from_txt2img:
  453. p.prompt = preview_prompt
  454. p.negative_prompt = preview_negative_prompt
  455. p.steps = preview_steps
  456. p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
  457. p.cfg_scale = preview_cfg_scale
  458. p.seed = preview_seed
  459. p.width = preview_width
  460. p.height = preview_height
  461. else:
  462. p.prompt = batch.cond_text[0]
  463. p.steps = 20
  464. p.width = training_width
  465. p.height = training_height
  466. preview_text = p.prompt
  467. processed = processing.process_images(p)
  468. image = processed.images[0] if len(processed.images) > 0 else None
  469. if unload:
  470. shared.sd_model.first_stage_model.to(devices.cpu)
  471. if image is not None:
  472. shared.state.assign_current_image(image)
  473. last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
  474. last_saved_image += f", prompt: {preview_text}"
  475. if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
  476. tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
  477. if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
  478. last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
  479. info = PngImagePlugin.PngInfo()
  480. data = torch.load(last_saved_file)
  481. info.add_text("sd-ti-embedding", embedding_to_b64(data))
  482. title = f"<{data.get('name', '???')}>"
  483. try:
  484. vectorSize = list(data['string_to_param'].values())[0].shape[0]
  485. except Exception:
  486. vectorSize = '?'
  487. checkpoint = sd_models.select_checkpoint()
  488. footer_left = checkpoint.model_name
  489. footer_mid = f'[{checkpoint.shorthash}]'
  490. footer_right = f'{vectorSize}v {steps_done}s'
  491. captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
  492. captioned_image = insert_image_data_embed(captioned_image, data)
  493. captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
  494. embedding_yet_to_be_embedded = False
  495. last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
  496. last_saved_image += f", prompt: {preview_text}"
  497. shared.state.job_no = embedding.step
  498. shared.state.textinfo = f"""
  499. <p>
  500. Loss: {loss_step:.7f}<br/>
  501. Step: {steps_done}<br/>
  502. Last prompt: {html.escape(batch.cond_text[0])}<br/>
  503. Last saved embedding: {html.escape(last_saved_file)}<br/>
  504. Last saved image: {html.escape(last_saved_image)}<br/>
  505. </p>
  506. """
  507. filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
  508. save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
  509. except Exception:
  510. errors.report("Error training embedding", exc_info=True)
  511. finally:
  512. pbar.leave = False
  513. pbar.close()
  514. shared.sd_model.first_stage_model.to(devices.device)
  515. shared.parallel_processing_allowed = old_parallel_processing_allowed
  516. sd_hijack_checkpoint.remove()
  517. return embedding, filename
  518. def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
  519. old_embedding_name = embedding.name
  520. old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
  521. old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
  522. old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
  523. try:
  524. embedding.sd_checkpoint = checkpoint.shorthash
  525. embedding.sd_checkpoint_name = checkpoint.model_name
  526. if remove_cached_checksum:
  527. embedding.cached_checksum = None
  528. embedding.name = embedding_name
  529. embedding.optimizer_state_dict = optimizer.state_dict()
  530. embedding.save(filename)
  531. except:
  532. embedding.sd_checkpoint = old_sd_checkpoint
  533. embedding.sd_checkpoint_name = old_sd_checkpoint_name
  534. embedding.name = old_embedding_name
  535. embedding.cached_checksum = old_cached_checksum
  536. raise