img2img.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import os
  2. from contextlib import closing
  3. from pathlib import Path
  4. import numpy as np
  5. from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
  6. import gradio as gr
  7. from modules import sd_samplers, images as imgutil
  8. from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
  9. from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
  10. from modules.shared import opts, state
  11. import modules.shared as shared
  12. import modules.processing as processing
  13. from modules.ui import plaintext_to_html
  14. import modules.scripts
  15. def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
  16. output_dir = output_dir.strip()
  17. processing.fix_seed(p)
  18. images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
  19. is_inpaint_batch = False
  20. if inpaint_mask_dir:
  21. inpaint_masks = shared.listfiles(inpaint_mask_dir)
  22. is_inpaint_batch = bool(inpaint_masks)
  23. if is_inpaint_batch:
  24. print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
  25. print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
  26. state.job_count = len(images) * p.n_iter
  27. # extract "default" params to use in case getting png info fails
  28. prompt = p.prompt
  29. negative_prompt = p.negative_prompt
  30. seed = p.seed
  31. cfg_scale = p.cfg_scale
  32. sampler_name = p.sampler_name
  33. steps = p.steps
  34. for i, image in enumerate(images):
  35. state.job = f"{i+1} out of {len(images)}"
  36. if state.skipped:
  37. state.skipped = False
  38. if state.interrupted:
  39. break
  40. try:
  41. img = Image.open(image)
  42. except UnidentifiedImageError as e:
  43. print(e)
  44. continue
  45. # Use the EXIF orientation of photos taken by smartphones.
  46. img = ImageOps.exif_transpose(img)
  47. if to_scale:
  48. p.width = int(img.width * scale_by)
  49. p.height = int(img.height * scale_by)
  50. p.init_images = [img] * p.batch_size
  51. image_path = Path(image)
  52. if is_inpaint_batch:
  53. # try to find corresponding mask for an image using simple filename matching
  54. if len(inpaint_masks) == 1:
  55. mask_image_path = inpaint_masks[0]
  56. else:
  57. # try to find corresponding mask for an image using simple filename matching
  58. mask_image_dir = Path(inpaint_mask_dir)
  59. masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
  60. if len(masks_found) == 0:
  61. print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
  62. continue
  63. # it should contain only 1 matching mask
  64. # otherwise user has many masks with the same name but different extensions
  65. mask_image_path = masks_found[0]
  66. mask_image = Image.open(mask_image_path)
  67. p.image_mask = mask_image
  68. if use_png_info:
  69. try:
  70. info_img = img
  71. if png_info_dir:
  72. info_img_path = os.path.join(png_info_dir, os.path.basename(image))
  73. info_img = Image.open(info_img_path)
  74. geninfo, _ = imgutil.read_info_from_image(info_img)
  75. parsed_parameters = parse_generation_parameters(geninfo)
  76. parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
  77. except Exception:
  78. parsed_parameters = {}
  79. p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
  80. p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
  81. p.seed = int(parsed_parameters.get("Seed", seed))
  82. p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
  83. p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
  84. p.steps = int(parsed_parameters.get("Steps", steps))
  85. proc = modules.scripts.scripts_img2img.run(p, *args)
  86. if proc is None:
  87. if output_dir:
  88. p.outpath_samples = output_dir
  89. p.override_settings['save_to_dirs'] = False
  90. if p.n_iter > 1 or p.batch_size > 1:
  91. p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
  92. else:
  93. p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
  94. process_images(p)
  95. def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
  96. override_settings = create_override_settings_dict(override_settings_texts)
  97. is_batch = mode == 5
  98. if mode == 0: # img2img
  99. image = init_img.convert("RGB")
  100. mask = None
  101. elif mode == 1: # img2img sketch
  102. image = sketch.convert("RGB")
  103. mask = None
  104. elif mode == 2: # inpaint
  105. image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
  106. mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
  107. image = image.convert("RGB")
  108. elif mode == 3: # inpaint sketch
  109. image = inpaint_color_sketch
  110. orig = inpaint_color_sketch_orig or inpaint_color_sketch
  111. pred = np.any(np.array(image) != np.array(orig), axis=-1)
  112. mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
  113. mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
  114. blur = ImageFilter.GaussianBlur(mask_blur)
  115. image = Image.composite(image.filter(blur), orig, mask.filter(blur))
  116. image = image.convert("RGB")
  117. elif mode == 4: # inpaint upload mask
  118. image = init_img_inpaint
  119. mask = init_mask_inpaint
  120. else:
  121. image = None
  122. mask = None
  123. # Use the EXIF orientation of photos taken by smartphones.
  124. if image is not None:
  125. image = ImageOps.exif_transpose(image)
  126. if selected_scale_tab == 1 and not is_batch:
  127. assert image, "Can't scale by because no image is selected"
  128. width = int(image.width * scale_by)
  129. height = int(image.height * scale_by)
  130. assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
  131. p = StableDiffusionProcessingImg2Img(
  132. sd_model=shared.sd_model,
  133. outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
  134. outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
  135. prompt=prompt,
  136. negative_prompt=negative_prompt,
  137. styles=prompt_styles,
  138. seed=seed,
  139. subseed=subseed,
  140. subseed_strength=subseed_strength,
  141. seed_resize_from_h=seed_resize_from_h,
  142. seed_resize_from_w=seed_resize_from_w,
  143. seed_enable_extras=seed_enable_extras,
  144. sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
  145. batch_size=batch_size,
  146. n_iter=n_iter,
  147. steps=steps,
  148. cfg_scale=cfg_scale,
  149. width=width,
  150. height=height,
  151. restore_faces=restore_faces,
  152. tiling=tiling,
  153. init_images=[image],
  154. mask=mask,
  155. mask_blur=mask_blur,
  156. inpainting_fill=inpainting_fill,
  157. resize_mode=resize_mode,
  158. denoising_strength=denoising_strength,
  159. image_cfg_scale=image_cfg_scale,
  160. inpaint_full_res=inpaint_full_res,
  161. inpaint_full_res_padding=inpaint_full_res_padding,
  162. inpainting_mask_invert=inpainting_mask_invert,
  163. override_settings=override_settings,
  164. )
  165. p.scripts = modules.scripts.scripts_img2img
  166. p.script_args = args
  167. p.user = request.username
  168. if shared.cmd_opts.enable_console_prompts:
  169. print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
  170. if mask:
  171. p.extra_generation_params["Mask blur"] = mask_blur
  172. with closing(p):
  173. if is_batch:
  174. assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
  175. process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
  176. processed = Processed(p, [], p.seed, "")
  177. else:
  178. processed = modules.scripts.scripts_img2img.run(p, *args)
  179. if processed is None:
  180. processed = process_images(p)
  181. shared.total_tqdm.clear()
  182. generation_info_js = processed.js()
  183. if opts.samples_log_stdout:
  184. print(generation_info_js)
  185. if opts.do_not_show_images:
  186. processed.images = []
  187. return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")