img2img.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import os
  2. from contextlib import closing
  3. from pathlib import Path
  4. import numpy as np
  5. from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
  6. import gradio as gr
  7. from modules import images
  8. from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
  9. from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
  10. from modules.shared import opts, state
  11. from modules.sd_models import get_closet_checkpoint_match
  12. import modules.shared as shared
  13. import modules.processing as processing
  14. from modules.ui import plaintext_to_html
  15. import modules.scripts
  16. def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
  17. output_dir = output_dir.strip()
  18. processing.fix_seed(p)
  19. if isinstance(input, str):
  20. batch_images = list(shared.walk_files(input, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
  21. else:
  22. batch_images = [os.path.abspath(x.name) for x in input]
  23. is_inpaint_batch = False
  24. if inpaint_mask_dir:
  25. inpaint_masks = shared.listfiles(inpaint_mask_dir)
  26. is_inpaint_batch = bool(inpaint_masks)
  27. if is_inpaint_batch:
  28. print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
  29. print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
  30. state.job_count = len(batch_images) * p.n_iter
  31. # extract "default" params to use in case getting png info fails
  32. prompt = p.prompt
  33. negative_prompt = p.negative_prompt
  34. seed = p.seed
  35. cfg_scale = p.cfg_scale
  36. sampler_name = p.sampler_name
  37. steps = p.steps
  38. override_settings = p.override_settings
  39. sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
  40. batch_results = None
  41. discard_further_results = False
  42. for i, image in enumerate(batch_images):
  43. state.job = f"{i+1} out of {len(batch_images)}"
  44. if state.skipped:
  45. state.skipped = False
  46. if state.interrupted or state.stopping_generation:
  47. break
  48. try:
  49. img = images.read(image)
  50. except UnidentifiedImageError as e:
  51. print(e)
  52. continue
  53. # Use the EXIF orientation of photos taken by smartphones.
  54. img = ImageOps.exif_transpose(img)
  55. if to_scale:
  56. p.width = int(img.width * scale_by)
  57. p.height = int(img.height * scale_by)
  58. p.init_images = [img] * p.batch_size
  59. image_path = Path(image)
  60. if is_inpaint_batch:
  61. # try to find corresponding mask for an image using simple filename matching
  62. if len(inpaint_masks) == 1:
  63. mask_image_path = inpaint_masks[0]
  64. else:
  65. # try to find corresponding mask for an image using simple filename matching
  66. mask_image_dir = Path(inpaint_mask_dir)
  67. masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
  68. if len(masks_found) == 0:
  69. print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
  70. continue
  71. # it should contain only 1 matching mask
  72. # otherwise user has many masks with the same name but different extensions
  73. mask_image_path = masks_found[0]
  74. mask_image = images.read(mask_image_path)
  75. p.image_mask = mask_image
  76. if use_png_info:
  77. try:
  78. info_img = img
  79. if png_info_dir:
  80. info_img_path = os.path.join(png_info_dir, os.path.basename(image))
  81. info_img = images.read(info_img_path)
  82. geninfo, _ = images.read_info_from_image(info_img)
  83. parsed_parameters = parse_generation_parameters(geninfo)
  84. parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
  85. except Exception:
  86. parsed_parameters = {}
  87. p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
  88. p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
  89. p.seed = int(parsed_parameters.get("Seed", seed))
  90. p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
  91. p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
  92. p.steps = int(parsed_parameters.get("Steps", steps))
  93. model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
  94. if model_info is not None:
  95. p.override_settings['sd_model_checkpoint'] = model_info.name
  96. elif sd_model_checkpoint_override:
  97. p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
  98. else:
  99. p.override_settings.pop("sd_model_checkpoint", None)
  100. if output_dir:
  101. p.outpath_samples = output_dir
  102. p.override_settings['save_to_dirs'] = False
  103. p.override_settings['save_images_replace_action'] = "Add number suffix"
  104. if p.n_iter > 1 or p.batch_size > 1:
  105. p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
  106. else:
  107. p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
  108. proc = modules.scripts.scripts_img2img.run(p, *args)
  109. if proc is None:
  110. p.override_settings.pop('save_images_replace_action', None)
  111. proc = process_images(p)
  112. if not discard_further_results and proc:
  113. if batch_results:
  114. batch_results.images.extend(proc.images)
  115. batch_results.infotexts.extend(proc.infotexts)
  116. else:
  117. batch_results = proc
  118. if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
  119. discard_further_results = True
  120. batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
  121. batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
  122. return batch_results
  123. def img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args):
  124. override_settings = create_override_settings_dict(override_settings_texts)
  125. is_batch = mode == 5
  126. if mode == 0: # img2img
  127. image = init_img
  128. mask = None
  129. elif mode == 1: # img2img sketch
  130. image = sketch
  131. mask = None
  132. elif mode == 2: # inpaint
  133. image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
  134. mask = processing.create_binary_mask(mask)
  135. elif mode == 3: # inpaint sketch
  136. image = inpaint_color_sketch
  137. orig = inpaint_color_sketch_orig or inpaint_color_sketch
  138. pred = np.any(np.array(image) != np.array(orig), axis=-1)
  139. mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
  140. mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
  141. blur = ImageFilter.GaussianBlur(mask_blur)
  142. image = Image.composite(image.filter(blur), orig, mask.filter(blur))
  143. elif mode == 4: # inpaint upload mask
  144. image = init_img_inpaint
  145. mask = init_mask_inpaint
  146. else:
  147. image = None
  148. mask = None
  149. image = images.fix_image(image)
  150. mask = images.fix_image(mask)
  151. if selected_scale_tab == 1 and not is_batch:
  152. assert image, "Can't scale by because no image is selected"
  153. width = int(image.width * scale_by)
  154. height = int(image.height * scale_by)
  155. assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
  156. p = StableDiffusionProcessingImg2Img(
  157. sd_model=shared.sd_model,
  158. outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
  159. outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
  160. prompt=prompt,
  161. negative_prompt=negative_prompt,
  162. styles=prompt_styles,
  163. batch_size=batch_size,
  164. n_iter=n_iter,
  165. cfg_scale=cfg_scale,
  166. width=width,
  167. height=height,
  168. init_images=[image],
  169. mask=mask,
  170. mask_blur=mask_blur,
  171. inpainting_fill=inpainting_fill,
  172. resize_mode=resize_mode,
  173. denoising_strength=denoising_strength,
  174. image_cfg_scale=image_cfg_scale,
  175. inpaint_full_res=inpaint_full_res,
  176. inpaint_full_res_padding=inpaint_full_res_padding,
  177. inpainting_mask_invert=inpainting_mask_invert,
  178. override_settings=override_settings,
  179. )
  180. p.scripts = modules.scripts.scripts_img2img
  181. p.script_args = args
  182. p.user = request.username
  183. if shared.opts.enable_console_prompts:
  184. print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
  185. with closing(p):
  186. if is_batch:
  187. if img2img_batch_source_type == "upload":
  188. assert isinstance(img2img_batch_upload, list) and img2img_batch_upload
  189. output_dir = ""
  190. inpaint_mask_dir = ""
  191. png_info_dir = img2img_batch_png_info_dir if not shared.cmd_opts.hide_ui_dir_config else ""
  192. processed = process_batch(p, img2img_batch_upload, output_dir, inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=png_info_dir)
  193. else: # "from dir"
  194. assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
  195. processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
  196. if processed is None:
  197. processed = Processed(p, [], p.seed, "")
  198. else:
  199. processed = modules.scripts.scripts_img2img.run(p, *args)
  200. if processed is None:
  201. processed = process_images(p)
  202. shared.total_tqdm.clear()
  203. generation_info_js = processed.js()
  204. if opts.samples_log_stdout:
  205. print(generation_info_js)
  206. if opts.do_not_show_images:
  207. processed.images = []
  208. return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")