textual_inversion.py

import os
import sys
import traceback
import inspect
from collections import namedtuple

import torch
import tqdm
import html
import datetime
import csv
import safetensors.torch
import numpy as np

from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter

from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.logging import save_settings_to_file

TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
textual_inversion_templates = {}


def list_textual_inversion_templates():
    textual_inversion_templates.clear()

    for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
        for fn in fns:
            path = os.path.join(root, fn)
            textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)

    return textual_inversion_templates
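

# Embedding wraps the trainable vectors for one textual inversion token plus
# training metadata. Embedding.save() below writes a plain torch-pickled dict,
# so a saved file can be inspected directly; a minimal sketch (the filename is
# hypothetical, for illustration only):
#
#     data = torch.load("my-embedding.pt", map_location="cpu")
#     vec = data["string_to_param"]["*"]  # shape: (num_vectors, embedding_dim)
#     print(data["name"], data["step"], vec.shape)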
class Embedding:
    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None

    def save(self, filename):
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, filename + '.optim')
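
    # checksum() condenses the embedding tensor into a short hex digest. The
    # rolling hash below is order-sensitive, so any change to the vectors
    # yields a different value; it is used to pair a saved .optim file with
    # the embedding it belongs to (see the optimizer loading in train_embedding).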
    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum


class DirWithTextualInversionEmbeddings:
    def __init__(self, path):
        self.path = path
        self.mtime = None

    def has_changed(self):
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        if self.mtime is None or mt > self.mtime:
            return True

        return False

    def update(self):
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)


class EmbeddingDatabase:
    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1
        self.embedding_dirs = {}

    def add_embedding_dir(self, path):
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        self.embedding_dirs.clear()
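
    # ids_lookup maps the first token id of each embedding's name to a list of
    # (token_ids, embedding) pairs, sorted longest-first so that a multi-token
    # embedding wins over a shorter one sharing the same leading token.
    # find_embedding_at_position() walks this list to match prompts.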
    def register_embedding(self, embedding, model):
        self.word_embeddings[embedding.name] = embedding

        ids = model.cond_stage_model.tokenize([embedding.name])[0]

        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []

        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding

    def get_expected_shape(self):
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]
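
    # load_from_file() accepts three on-disk forms: images carrying an
    # embedded payload (either a base64 'sd-ti-embedding' text chunk or data
    # embedded in the pixels), torch-pickled .pt/.bin files in the webui's
    # string_to_param format, and .safetensors files. Diffusers "concept"
    # files, a bare {token: tensor} dict, are also recognized.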
    def load_from_file(self, path, filename):
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

            embed_image = Image.open(path)
            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                name = data.get('name', name)
            else:
                data = extract_image_data_embed(embed_image)
                name = data.get('name', name)
        elif ext in ['.BIN', '.PT']:
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return

        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
        else:
            raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")

        vec = emb.detach().to(devices.device, dtype=torch.float32)
        embedding = Embedding(vec, name)
        embedding.step = data.get('step', None)
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
        embedding.vectors = vec.shape[0]
        embedding.shape = vec.shape[-1]
        embedding.filename = path

        if self.expected_shape == -1 or self.expected_shape == embedding.shape:
            self.register_embedding(embedding, shared.sd_model)
        else:
            self.skipped_embeddings[name] = embedding

    def load_from_dir(self, embdir):
        if not os.path.isdir(embdir.path):
            return

        for root, dirs, fns in os.walk(embdir.path):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)

                    if os.stat(fullfn).st_size == 0:
                        continue

                    self.load_from_file(fullfn, fn)
                except Exception:
                    print(f"Error loading embedding {fn}:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False):
        if not force_reload:
            need_reload = False
            for path, embdir in self.embedding_dirs.items():
                if embdir.has_changed():
                    need_reload = True
                    break

            if not need_reload:
                return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        for path, embdir in self.embedding_dirs.items():
            self.load_from_dir(embdir)
            embdir.update()

        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
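
    # Given a prompt's token id stream and an offset into it, return the
    # embedding whose ids match at that position plus the number of tokens it
    # spans, or (None, None) when nothing matches.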
    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
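

# create_embedding seeds a new embedding from init_text: the text is encoded
# once and its token vectors are spread across the requested rows via the
# index i * shape[0] // num_vectors_per_token. For instance, if the encoded
# init_text yields two vectors and num_vectors_per_token is 4, rows 0-1 copy
# the first token's vector and rows 2-3 copy the second's.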
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    cond_model = shared.sd_model.cond_stage_model

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    # cond_model expects at least some text, so we provide '*' as a backup.
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    # Only copy if we provided an init_text, otherwise keep vectors as zeros
    if init_text:
        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn
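

# write_loss appends one CSV row every training_write_csv_every steps. Epoch
# bookkeeping is derived from the step count: with epoch_len=100, step=250
# lands in epoch 2 at epoch_step 49 (both computed from step - 1, so step 100
# still belongs to epoch 0).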
def write_loss(log_directory, filename, step, epoch_len, values):
    if shared.opts.training_write_csv_every == 0:
        return

    if step % shared.opts.training_write_csv_every != 0:
        return

    write_csv_header = not os.path.exists(os.path.join(log_directory, filename))

    with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if write_csv_header:
            csv_writer.writeheader()

        epoch = (step - 1) // epoch_len
        epoch_step = (step - 1) % epoch_len

        csv_writer.writerow({
            "step": step,
            "epoch": epoch,
            "epoch_step": epoch_step,
            **values,
        })


def tensorboard_setup(log_directory):
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
    return SummaryWriter(
        log_dir=os.path.join(log_directory, "tensorboard"),
        flush_secs=shared.opts.training_tensorboard_flush_every)


def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
    tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)


def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)


def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    # Convert a PIL image to a torch tensor: HWC uint8 first, then permute to
    # the CHW layout that SummaryWriter.add_image expects by default.
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
    img_tensor = img_tensor.permute((2, 0, 1))

    tensorboard_writer.add_image(tag, img_tensor, global_step=step)


def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
    assert gradient_step > 0, "Gradient accumulation step must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_filename, "Prompt template file not selected"
    assert template_file, f"Prompt template file {template_filename} not found"
    assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"

    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"
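

# train_embedding drives the whole textual inversion training loop: it
# validates inputs, builds the dataset and dataloader, optionally restores
# AdamW state from a matching .optim file, then runs gradient-accumulated
# mixed-precision training, periodically saving checkpoints and previews.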
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    template_file = textual_inversion_templates.get(template_filename, None)
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
    template_file = template_file.path

    shared.state.job = "train-embedding"
    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    hijack = sd_hijack.model_hijack

    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()

    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = tensorboard_setup(log_directory)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)

    if shared.opts.save_training_settings_to_txt:
        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(filename + '.optim'):
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")
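
    # Mixed-precision setup: GradScaler scales the loss before backward() so
    # small fp16 gradients don't flush to zero, then unscales them again
    # before the optimizer step.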
    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # number of images processed = batch_size * gradient_step * optimizer steps
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
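    # For example, with len(ds) == 32, batch_size=2, gradient_step=3:
    # 16 batches per pass, steps_per_epoch = 5 optimizer steps, and
    # max_steps_per_epoch = 15 batches, dropping the last incomplete
    # accumulation group the way drop_last=True would.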

    loss_step = 0
    _loss_step = 0  # internal

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
    img_c = None

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        sd_hijack_checkpoint.add()

        for i in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, embedding.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)
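
                # The forward pass below runs under autocast so the frozen SD
                # model can execute in reduced precision while the trained
                # embedding stays fp32. Inpainting models ('hybrid'/'concat'
                # conditioning) additionally expect an image conditioning
                # tensor, hence the c_concat/c_crossattn dict.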
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)

                    if is_training_inpainting_model:
                        if img_c is None:
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)

                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    else:
                        cond = c

                    loss = shared.sd_model(x, cond)[0] / gradient_step
                    del x

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # keep accumulating until a full gradient accumulation group is done
                if (j + 1) % gradient_step != 0:
                    continue

                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = embedding.step + 1

                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
                pbar.set_description(description)
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{embedding_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                        do_not_reload_embeddings=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt
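
                    # Render the preview using the in-training embedding; the
                    # VAE (first_stage_model) was moved back to the GPU above
                    # and is offloaded again below when unload is enabled.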
                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    if image is not None:
                        shared.state.assign_current_image(image)

                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                        if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
                            last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')

                            info = PngImagePlugin.PngInfo()
                            data = torch.load(last_saved_file)
                            info.add_text("sd-ti-embedding", embedding_to_b64(data))

                            title = "<{}>".format(data.get('name', '???'))

                            try:
                                vectorSize = list(data['string_to_param'].values())[0].shape[0]
                            except Exception:
                                vectorSize = '?'

                            checkpoint = sd_models.select_checkpoint()
                            footer_left = checkpoint.model_name
                            footer_mid = '[{}]'.format(checkpoint.shorthash)
                            footer_right = '{}v {}s'.format(vectorSize, steps_done)

                            captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                            captioned_image = insert_image_data_embed(captioned_image, data)
                            captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                            embedding_yet_to_be_embedded = False

                shared.state.job_no = embedding.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed
        sd_hijack_checkpoint.remove()

    return embedding, filename
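

# save_embedding swaps the embedding's metadata (name, checkpoint hash) to
# match the file being written and restores the previous values if the save
# raises, so a failed periodic save cannot corrupt in-memory training state.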
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    old_embedding_name = embedding.name
    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
    try:
        embedding.sd_checkpoint = checkpoint.shorthash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except:
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        raise