textual_inversion.py

import os
from collections import namedtuple
from contextlib import closing

import torch
import tqdm
import html
import datetime
import csv
import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin

from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes, cache
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.saving_settings import save_settings_to_file

TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
textual_inversion_templates = {}


def list_textual_inversion_templates():
    textual_inversion_templates.clear()

    for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
        for fn in fns:
            path = os.path.join(root, fn)
            textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)

    return textual_inversion_templates

class Embedding:
    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None
        self.filename = None
        self.hash = None
        self.shorthash = None

    def save(self, filename):
        embedding_data = {
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, f"{filename}.optim")

    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]
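
# Illustrative sketch (not part of the module): the .pt file written by
# Embedding.save() round-trips through torch.load() as a plain dict, e.g.
#
#   data = torch.load("my-embed.pt", map_location="cpu")
#   data["string_to_param"]["*"]   # tensor of shape (num_vectors, embedding_dim)
#   data["name"], data["step"]     # training metadata
#
# When shared.opts.save_optimizer_state is enabled, a "my-embed.pt.optim"
# sidecar holding the optimizer state is written next to it, keyed to checksum().
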
class DirWithTextualInversionEmbeddings:
    def __init__(self, path):
        self.path = path
        self.mtime = None

    def has_changed(self):
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        if self.mtime is None or mt > self.mtime:
            return True

        return False

    def update(self):
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)

class EmbeddingDatabase:
    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.expected_shape = -1
        self.embedding_dirs = {}
        self.previously_displayed_embeddings = ()
        self.image_embedding_cache = cache.cache('image-embedding')

    def add_embedding_dir(self, path):
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        return self.register_embedding_by_name(embedding, model, embedding.name)

    def register_embedding_by_name(self, embedding, model, name):
        ids = model.cond_stage_model.tokenize([name])[0]
        first_id = ids[0]

        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []

        if name in self.word_embeddings:
            # remove old one from the lookup list
            lookup = [x for x in self.ids_lookup[first_id] if x[1].name != name]
        else:
            lookup = self.ids_lookup[first_id]

        if embedding is not None:
            lookup += [(ids, embedding)]

        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)

        if embedding is None:
            # unregister embedding with specified name
            if name in self.word_embeddings:
                del self.word_embeddings[name]

            if len(self.ids_lookup[first_id]) == 0:
                del self.ids_lookup[first_id]

            return None

        self.word_embeddings[name] = embedding
        return embedding
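
    # Sketch of the lookup structure this maintains (token ids are illustrative):
    # an embedding named "my-style" that tokenizes to [320, 1125] is stored as
    #   ids_lookup[320] = [([320, 1125], <Embedding my-style>), ...]
    # with each first-token bucket kept sorted longest-id-sequence first, so
    # find_embedding_at_position() below can match greedily.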

    def get_expected_shape(self):
        devices.torch_npu_set_device()
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def read_embedding_from_image(self, path, name):
        try:
            ondisk_mtime = os.path.getmtime(path)
            if (cache_embedding := self.image_embedding_cache.get(path)) and ondisk_mtime == cache_embedding.get('mtime', 0):
                # the cache is only used if the file's modification time matches
                return cache_embedding.get('data', None), cache_embedding.get('name', None)

            embed_image = Image.open(path)

            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                name = data.get('name', name)
            elif data := extract_image_data_embed(embed_image):
                name = data.get('name', name)

            if data is None or shared.opts.textual_inversion_image_embedding_data_cache:
                # embedding data is only cached if the textual_inversion_image_embedding_data_cache option is enabled;
                # results for images that are not embeddings are always cached, to avoid unnecessary future disk reads
                self.image_embedding_cache[path] = {'data': data, 'name': None if data is None else name, 'mtime': ondisk_mtime}

            return data, name
        except Exception:
            errors.report(f"Error loading embedding {path}", exc_info=True)
        return None, None
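
    # Two ways an image can carry an embedding (both handled above): a
    # base64-encoded 'sd-ti-embedding' PNG text chunk, decoded with
    # embedding_from_b64(), or data embedded in the image's pixel data itself,
    # recovered with extract_image_data_embed().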

    def load_from_file(self, path, filename):
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            _, second_ext = os.path.splitext(name)
            if second_ext.upper() == '.PREVIEW':
                return

            data, name = self.read_embedding_from_image(path, name)
            if data is None:
                return
        elif ext in ['.BIN', '.PT']:
            data = torch.load(path, map_location="cpu")
        elif ext in ['.SAFETENSORS']:
            data = safetensors.torch.load_file(path, device="cpu")
        else:
            return

        if data is not None:
            embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)

            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                self.register_embedding(embedding, shared.sd_model)
            else:
                self.skipped_embeddings[name] = embedding
        else:
            print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")

    def load_from_dir(self, embdir):
        if not os.path.isdir(embdir.path):
            return

        for root, _, fns in os.walk(embdir.path, followlinks=True):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)

                    if os.stat(fullfn).st_size == 0:
                        continue

                    self.load_from_file(fullfn, fn)
                except Exception:
                    errors.report(f"Error loading embedding {fn}", exc_info=True)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False):
        if not force_reload:
            need_reload = False
            for embdir in self.embedding_dirs.values():
                if embdir.has_changed():
                    need_reload = True
                    break

            if not need_reload:
                return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        for embdir in self.embedding_dirs.values():
            self.load_from_dir(embdir)
            embdir.update()

        # re-sort word_embeddings because load_from_dir may not load in alphabetical order.
        # use a temporary copy so we don't reinitialize self.word_embeddings in case other objects hold a reference to it.
        sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
        self.word_embeddings.clear()
        self.word_embeddings.update(sorted_word_embeddings)

        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
        if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
            self.previously_displayed_embeddings = displayed_embeddings
            print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
            if self.skipped_embeddings:
                print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
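
    # Illustrative use of the greedy match: with tokens = [320, 1125, 9] and the
    # lookup sketched above, find_embedding_at_position(tokens, 0) compares
    # tokens[0:2] against the registered id sequence [320, 1125] and returns
    # (<Embedding my-style>, 2); the caller then skips past those 2 tokens.
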
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    cond_model = shared.sd_model.cond_stage_model

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    # cond_model expects at least some text, so we provide '*' as backup.
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    # Only copy if we provided an init_text, otherwise keep vectors as zeros
    if init_text:
        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")

    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn
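
# Hypothetical call, for illustration only:
#   create_embedding("my-style", num_vectors_per_token=4, overwrite_old=False, init_text="oil painting")
# writes <embeddings_dir>/my-style.pt holding a (4, embedding_dim) vector seeded
# from the prompt "oil painting"; with init_text='' the vectors stay all zeros.
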
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
    if 'string_to_param' in data:  # textual inversion embeddings
        param_dict = data['string_to_param']
        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
        emb = next(iter(param_dict.items()))[1]
        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
        vectors = data['clip_g'].shape[0]
    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

        emb = next(iter(data.values()))
        if len(emb.shape) == 1:
            emb = emb.unsqueeze(0)

        vec = emb.detach().to(devices.device, dtype=torch.float32)
        shape = vec.shape[-1]
        vectors = vec.shape[0]
    else:
        raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")

    embedding = Embedding(vec, name)
    embedding.step = data.get('step', None)
    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
    embedding.vectors = vectors
    embedding.shape = shape

    if filepath:
        embedding.filename = filepath
        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')

    return embedding
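
# Three on-disk layouts are recognized above (keys and dims are illustrative):
#   webui/A1111:       {'string_to_param': {'*': Tensor(n, dim)}, 'name': ..., 'step': ...}
#   SDXL (two CLIPs):  {'clip_g': Tensor(n, 1280), 'clip_l': Tensor(n, 768)}
#   diffusers concept: {'<token>': Tensor(dim) or Tensor(n, dim)}
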
def write_loss(log_directory, filename, step, epoch_len, values):
    if shared.opts.training_write_csv_every == 0:
        return

    if step % shared.opts.training_write_csv_every != 0:
        return

    write_csv_header = not os.path.exists(os.path.join(log_directory, filename))

    with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if write_csv_header:
            csv_writer.writeheader()

        epoch = (step - 1) // epoch_len
        epoch_step = (step - 1) % epoch_len

        csv_writer.writerow({
            "step": step,
            "epoch": epoch,
            "epoch_step": epoch_step,
            **values,
        })
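
# With values={'loss': ..., 'learn_rate': ...}, as passed by train_embedding
# below, each CSV row looks like (numbers illustrative):
#   step,epoch,epoch_step,loss,learn_rate
#   500,4,99,0.1234567,0.005
# where epoch and epoch_step are derived from (step - 1) and epoch_len.
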
def tensorboard_setup(log_directory):
    from torch.utils.tensorboard import SummaryWriter
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
    return SummaryWriter(
        log_dir=os.path.join(log_directory, "tensorboard"),
        flush_secs=shared.opts.training_tensorboard_flush_every)


def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
    tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)


def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    tensorboard_writer.add_scalar(tag=tag, scalar_value=value, global_step=step)


def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    # Convert a pil image to a torch tensor
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
    img_tensor = img_tensor.permute((2, 0, 1))

    tensorboard_writer.add_image(tag, img_tensor, global_step=step)
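
# SummaryWriter.add_image expects CHW data by default (dataformats='CHW'),
# hence the permute from PIL's HWC layout above.
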
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
    assert gradient_step > 0, "Gradient accumulation step must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_filename, "Prompt template file not selected"
    assert template_file, f"Prompt template file {template_filename} not found"
    assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"

    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"

def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
    from modules import processing

    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    template_file = textual_inversion_templates.get(template_filename, None)
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
    template_file = template_file.path

    shared.state.job = "train-embedding"
    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    hijack = sd_hijack.model_hijack

    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()

    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

    tensorboard_writer = None
    if shared.opts.training_enable_tensorboard:
        try:
            tensorboard_writer = tensorboard_setup(log_directory)
        except ImportError:
            errors.report("Error initializing tensorboard", exc_info=True)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)

    if shared.opts.save_training_settings_to_txt:
        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(f"{filename}.optim"):
            optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
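    # Worked example (illustrative numbers): with len(ds)=100, batch_size=4 and
    # gradient_step=8 there are 100 // 4 = 25 batches per epoch, giving
    # steps_per_epoch = 25 // 8 = 3 optimizer steps; max_steps_per_epoch
    # = 25 - 25 % 8 = 24 truncates each epoch to whole accumulation windows,
    # acting like drop_last=True at accumulation granularity.
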
    loss_step = 0
    _loss_step = 0  # internal

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
    img_c = None

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        sd_hijack_checkpoint.add()

        for _ in range((steps - initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, embedding.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if use_weight:
                        w = batch.weight.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)

                    if is_training_inpainting_model:
                        if img_c is None:
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)

                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    else:
                        cond = c

                    if use_weight:
                        loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
                        del w
                    else:
                        loss = shared.sd_model.forward(x, cond)[0] / gradient_step
                    del x

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

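                # Gradient accumulation: each micro-batch contributes
                # loss / gradient_step (see above), so the gradients summed by
                # backward() over gradient_step batches average out to one
                # effective batch of size batch_size * gradient_step; only then
                # is the optimizer actually stepped below.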
                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = embedding.step + 1

                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step + 1}/{steps_per_epoch}] loss: {loss_step:.7f}"
                pbar.set_description(description)
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{embedding_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                        do_not_reload_embeddings=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    with closing(p):
                        processed = processing.process_images(p)
                        image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    if image is not None:
                        shared.state.assign_current_image(image)

                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                        if tensorboard_writer and shared.opts.training_tensorboard_save_images:
                            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')

                        info = PngImagePlugin.PngInfo()
                        data = torch.load(last_saved_file)
                        info.add_text("sd-ti-embedding", embedding_to_b64(data))

                        title = f"<{data.get('name', '???')}>"

                        try:
                            vectorSize = list(data['string_to_param'].values())[0].shape[0]
                        except Exception:
                            vectorSize = '?'

                        checkpoint = sd_models.select_checkpoint()
                        footer_left = checkpoint.model_name
                        footer_mid = f'[{checkpoint.shorthash}]'
                        footer_right = f'{vectorSize}v {steps_done}s'

                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                        captioned_image = insert_image_data_embed(captioned_image, data)

                        captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                        embedding_yet_to_be_embedded = False
                shared.state.job_no = embedding.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        errors.report("Error training embedding", exc_info=True)
    finally:
        pbar.leave = False
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed
        sd_hijack_checkpoint.remove()

    return embedding, filename

def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    old_embedding_name = embedding.name
    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
    try:
        embedding.sd_checkpoint = checkpoint.shorthash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except:  # deliberately bare: restore the previous metadata on any failure, then re-raise
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        raise